From 0551070bfc7767fbf47e05150ae05a4e387645c7 Mon Sep 17 00:00:00 2001 From: chai3 <16031581+chai3@users.noreply.github.com> Date: Thu, 8 Jun 2023 15:42:15 +0900 Subject: [PATCH 01/83] fix: raise TypeError on wrong device type in Pipeline.to and Inference.to Fixes 1397 --- pyannote/audio/core/inference.py | 5 +++++ pyannote/audio/core/pipeline.py | 7 ++++++- pyannote/audio/pipelines/speaker_verification.py | 15 +++++++++++++++ 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/pyannote/audio/core/inference.py b/pyannote/audio/core/inference.py index c38fef0f9..3b9ee8058 100644 --- a/pyannote/audio/core/inference.py +++ b/pyannote/audio/core/inference.py @@ -170,6 +170,11 @@ def __init__( def to(self, device: torch.device): """Send internal model to `device`""" + if not isinstance(device, torch.device): + raise TypeError( + f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`" + ) + self.model.to(device) if self.model.specifications.powerset and not self.skip_conversion: self._powerset.to(device) diff --git a/pyannote/audio/core/pipeline.py b/pyannote/audio/core/pipeline.py index a5b4b1bc5..f844d584f 100644 --- a/pyannote/audio/core/pipeline.py +++ b/pyannote/audio/core/pipeline.py @@ -324,9 +324,14 @@ def __call__(self, file: AudioFile, **kwargs): return self.apply(file, **kwargs) - def to(self, device): + def to(self, device: torch.device): """Send pipeline to `device`""" + if not isinstance(device, torch.device): + raise TypeError( + f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`" + ) + for _, pipeline in self._pipelines.items(): if hasattr(pipeline, "to"): _ = pipeline.to(device) diff --git a/pyannote/audio/pipelines/speaker_verification.py b/pyannote/audio/pipelines/speaker_verification.py index 1a672d614..d99537cbc 100644 --- a/pyannote/audio/pipelines/speaker_verification.py +++ b/pyannote/audio/pipelines/speaker_verification.py @@ -80,6 +80,11 @@ def __init__( self.model_.to(self.device) def to(self, device: torch.device): + if not isinstance(device, torch.device): + raise TypeError( + f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`" + ) + self.model_.to(device) self.device = device return self @@ -255,6 +260,11 @@ def __init__( ) def to(self, device: torch.device): + if not isinstance(device, torch.device): + raise TypeError( + f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`" + ) + self.classifier_ = SpeechBrain_EncoderClassifier.from_hparams( source=self.embedding, savedir=f"{CACHE_DIR}/speechbrain", @@ -415,6 +425,11 @@ def __init__( self.model_.to(self.device) def to(self, device: torch.device): + if not isinstance(device, torch.device): + raise TypeError( + f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`" + ) + self.model_.to(device) self.device = device return self From 30ddb0b6f641217fc00cc47fc34f71cb3c611d0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Mon, 12 Jun 2023 12:21:09 +0200 Subject: [PATCH 02/83] feat(task): add support for multi-task models (#1374) BREAKING(model): get rid of (flaky) `Model.introspection` --- CHANGELOG.md | 10 +- pyannote/audio/core/inference.py | 325 +++++++++------ pyannote/audio/core/io.py | 18 +- pyannote/audio/core/model.py | 392 +++++------------- pyannote/audio/core/task.py | 51 ++- pyannote/audio/models/segmentation/PyanNet.py | 5 +- pyannote/audio/models/segmentation/debug.py | 12 +- .../pipelines/overlapped_speech_detection.py | 2 +- pyannote/audio/pipelines/resegmentation.py | 3 +- .../audio/pipelines/speaker_diarization.py | 2 +- .../audio/pipelines/speaker_verification.py | 51 +-- pyannote/audio/pipelines/utils/oracle.py | 2 +- pyannote/audio/tasks/embedding/mixins.py | 9 +- pyannote/audio/tasks/segmentation/mixins.py | 12 +- .../audio/tasks/segmentation/multilabel.py | 27 +- .../overlapped_speech_detection.py | 21 +- .../tasks/segmentation/speaker_diarization.py | 38 +- .../segmentation/voice_activity_detection.py | 21 +- pyannote/audio/utils/multi_task.py | 59 +++ pyannote/audio/utils/powerset.py | 22 +- pyannote/audio/utils/preview.py | 2 +- tests/inference_test.py | 9 +- tests/test_train.py | 44 +- tutorials/add_your_own_task.ipynb | 100 ++--- tutorials/overlapped_speech_detection.ipynb | 18 +- 25 files changed, 628 insertions(+), 627 deletions(-) create mode 100644 pyannote/audio/utils/multi_task.py diff --git a/CHANGELOG.md b/CHANGELOG.md index d7e50dba7..bc51c0e9c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,10 +14,10 @@ ### Breaking changes - BREAKING(task): rename `Segmentation` task to `SpeakerDiarization` - - BREAKING(task): remove support for variable chunk duration + - BREAKING(task): remove support for variable chunk duration for segmentation tasks - BREAKING(pipeline): pipeline defaults to CPU (use `pipeline.to(device)`) - BREAKING(pipeline): remove `SpeakerSegmentation` pipeline (use `SpeakerDiarization` pipeline) - - BREAKING(pipeline): remove support `FINCHClustering` and `HiddenMarkovModelClustering` + - BREAKING(pipeline): remove support for `FINCHClustering` and `HiddenMarkovModelClustering` - BREAKING(pipeline): remove `segmentation_duration` parameter from `SpeakerDiarization` pipeline (defaults to `duration` of segmentation model) - BREAKING(setup): drop support for Python 3.7 - BREAKING(io): channels are now 0-indexed (used to be 1-indexed) @@ -26,9 +26,14 @@ * replace `Audio()` by `Audio(mono="downmix")`; * replace `Audio(mono=True)` by `Audio(mono="downmix")`; * replace `Audio(mono=False)` by `Audio()`. + - BREAKING(model): get rid of (flaky) `Model.introspection` + If, for some weird reason, you wrote some custom code based on that, + you should instead rely on `Model.example_output`. + ### Features and improvements + - feat(task): add support for multi-task models - feat(pipeline): send pipeline to device with `pipeline.to(device)` - feat(pipeline): make `segmentation_batch_size` and `embedding_batch_size` mutable in `SpeakerDiarization` pipeline (they now default to `1`) - feat(task): add [powerset](https://arxiv.org/PLACEHOLDER) support to `SpeakerDiarization` task @@ -44,6 +49,7 @@ - fix(pipeline): fix reproducibility issue with Ampere CUDA devices - fix(pipeline): fix support for IOBase audio - fix(pipeline): fix corner case with no speaker + - fix(train): prevent metadata preparation to happen twice - improve(task): shorten and improve structure of Tensorboard tags ### Dependencies diff --git a/pyannote/audio/core/inference.py b/pyannote/audio/core/inference.py index 3b9ee8058..703aa06cb 100644 --- a/pyannote/audio/core/inference.py +++ b/pyannote/audio/core/inference.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -27,19 +27,20 @@ import numpy as np import torch +import torch.nn as nn +import torch.nn.functional as F from einops import rearrange from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature from pytorch_lightning.utilities.memory import is_oom_error from pyannote.audio.core.io import AudioFile -from pyannote.audio.core.model import Model +from pyannote.audio.core.model import Model, Specifications from pyannote.audio.core.task import Resolution +from pyannote.audio.utils.multi_task import map_with_specifications from pyannote.audio.utils.permutation import mae_cost_func, permutate from pyannote.audio.utils.powerset import Powerset from pyannote.audio.utils.reproducibility import fix_reproducibility -TaskName = Union[Text, None] - class BaseInference: pass @@ -68,10 +69,10 @@ class Inference(BaseInference): skip_aggregation : bool, optional Do not aggregate outputs when using "sliding" window. Defaults to False. skip_conversion: bool, optional - In case `model` has been trained with `powerset` mode, its output is automatically + In case a task has been trained with `powerset` mode, output is automatically converted to `multi-label`, unless `skip_conversion` is set to True. batch_size : int, optional - Batch size. Larger values make inference faster. Defaults to 32. + Batch size. Larger values (should) make inference faster. Defaults to 32. device : torch.device, optional Device used for inference. Defaults to `model.device`. In case `device` and `model.device` are different, model is sent to device. @@ -94,6 +95,7 @@ def __init__( batch_size: int = 32, use_auth_token: Union[Text, None] = None, ): + # ~~~~ model ~~~~~ self.model = ( model @@ -106,50 +108,70 @@ def __init__( ) ) - if window not in ["sliding", "whole"]: - raise ValueError('`window` must be "sliding" or "whole".') - - specifications = self.model.specifications - if specifications.resolution == Resolution.FRAME and window == "whole": - warnings.warn( - 'Using "whole" `window` inference with a frame-based model might lead to bad results ' - 'and huge memory consumption: it is recommended to set `window` to "sliding".' - ) - - self.window = window - self.skip_aggregation = skip_aggregation - if device is None: device = self.model.device self.device = device - self.pre_aggregation_hook = pre_aggregation_hook - self.model.eval() self.model.to(self.device) - # chunk duration used during training specifications = self.model.specifications - training_duration = specifications.duration - if duration is None: - duration = training_duration - elif training_duration != duration: + # ~~~~ sliding window ~~~~~ + + if window not in ["sliding", "whole"]: + raise ValueError('`window` must be "sliding" or "whole".') + + if window == "whole" and any( + s.resolution == Resolution.FRAME for s in specifications + ): + warnings.warn( + 'Using "whole" `window` inference with a frame-based model might lead to bad results ' + 'and huge memory consumption: it is recommended to set `window` to "sliding".' + ) + self.window = window + + training_duration = next(iter(specifications)).duration + duration = duration or training_duration + if training_duration != duration: warnings.warn( f"Model was trained with {training_duration:g}s chunks, and you requested " f"{duration:g}s chunks for inference: this might lead to suboptimal results." ) self.duration = duration - self.warm_up = specifications.warm_up + # ~~~~ powerset to multilabel conversion ~~~~ + + self.skip_conversion = skip_conversion + + conversion = list() + for s in specifications: + if s.powerset and not skip_conversion: + c = Powerset(len(s.classes), s.powerset_max_classes) + else: + c = nn.Identity() + conversion.append(c.to(self.device)) + + if isinstance(specifications, Specifications): + self.conversion = conversion[0] + else: + self.conversion = nn.ModuleList(conversion) + + # ~~~~ overlap-add aggregation ~~~~~ + + self.skip_aggregation = skip_aggregation + self.pre_aggregation_hook = pre_aggregation_hook + + self.warm_up = next(iter(specifications)).warm_up # Use that many seconds on the left- and rightmost parts of each chunk # to warm up the model. While the model does process those left- and right-most # parts, only the remaining central part of each chunk is used for aggregating # scores during inference. # step between consecutive chunks - if step is None: - step = 0.1 * self.duration if self.warm_up[0] == 0.0 else self.warm_up[0] + step = step or ( + 0.1 * self.duration if self.warm_up[0] == 0.0 else self.warm_up[0] + ) if step > self.duration: raise ValueError( @@ -160,14 +182,8 @@ def __init__( self.step = step self.batch_size = batch_size - self.skip_conversion = skip_conversion - if specifications.powerset and not self.skip_conversion: - self._powerset = Powerset( - len(specifications.classes), specifications.powerset_max_classes - ) - self._powerset.to(self.device) - def to(self, device: torch.device): + def to(self, device: torch.device) -> "Inference": """Send internal model to `device`""" if not isinstance(device, torch.device): @@ -176,12 +192,11 @@ def to(self, device: torch.device): ) self.model.to(device) - if self.model.specifications.powerset and not self.skip_conversion: - self._powerset.to(device) + self.conversion.to(device) self.device = device return self - def infer(self, chunks: torch.Tensor) -> np.ndarray: + def infer(self, chunks: torch.Tensor) -> Union[np.ndarray, Tuple[np.ndarray]]: """Forward pass Takes care of sending chunks to right device and outputs back to CPU @@ -193,11 +208,11 @@ def infer(self, chunks: torch.Tensor) -> np.ndarray: Returns ------- - outputs : (batch_size, ...) np.ndarray + outputs : (tuple of) (batch_size, ...) np.ndarray Model output. """ - with torch.no_grad(): + with torch.inference_mode(): try: outputs = self.model(chunks.to(self.device)) except RuntimeError as exception: @@ -209,22 +224,19 @@ def infer(self, chunks: torch.Tensor) -> np.ndarray: else: raise exception - # convert powerset to multi-label unless specifically requested not to - if self.model.specifications.powerset and not self.skip_conversion: - powerset = torch.nn.functional.one_hot( - torch.argmax(outputs, dim=-1), - self.model.specifications.num_powerset_classes, - ).float() - outputs = self._powerset.to_multilabel(powerset) + def __convert(output: torch.Tensor, conversion: nn.Module, **kwargs): + return conversion(output).cpu().numpy() - return outputs.cpu().numpy() + return map_with_specifications( + self.model.specifications, __convert, outputs, self.conversion + ) def slide( self, waveform: torch.Tensor, sample_rate: int, hook: Optional[Callable], - ) -> SlidingWindowFeature: + ) -> Union[SlidingWindowFeature, Tuple[SlidingWindowFeature]]: """Slide model on a waveform Parameters @@ -241,23 +253,29 @@ def slide( Returns ------- - output : SlidingWindowFeature + output : (tuple of) SlidingWindowFeature Model output. Shape is (num_chunks, dimension) for chunk-level tasks, and (num_frames, dimension) for frame-level tasks. """ - window_size: int = round(self.duration * sample_rate) + window_size: int = self.model.audio.get_num_samples(self.duration) step_size: int = round(self.step * sample_rate) _, num_samples = waveform.shape - specifications = self.model.specifications - resolution = specifications.resolution - introspection = self.model.introspection - if resolution == Resolution.CHUNK: - frames = SlidingWindow(start=0.0, duration=self.duration, step=self.step) - elif resolution == Resolution.FRAME: - frames = introspection.frames - num_frames_per_chunk, dimension = introspection(window_size) + frames = self.model.example_output.frames + + def __frames( + frames, specifications: Optional[Specifications] = None + ) -> SlidingWindow: + if specifications.resolution == Resolution.CHUNK: + return SlidingWindow(start=0.0, duration=self.duration, step=self.step) + return frames + + frames: Union[SlidingWindow, Tuple[SlidingWindow]] = map_with_specifications( + self.model.specifications, + __frames, + self.model.example_output.frames, + ) # prepare complete chunks if num_samples >= window_size: @@ -274,75 +292,113 @@ def slide( num_samples - window_size ) % step_size > 0 if has_last_chunk: + # pad last chunk with zeros last_chunk: torch.Tensor = waveform[:, num_chunks * step_size :] + _, last_window_size = last_chunk.shape + last_pad = window_size - last_window_size + last_chunk = F.pad(last_chunk, (0, last_pad)) + + def __empty_list(**kwargs): + return list() - outputs: Union[List[np.ndarray], np.ndarray] = list() + outputs: Union[ + List[np.ndarray], Tuple[List[np.ndarray]] + ] = map_with_specifications(self.model.specifications, __empty_list) if hook is not None: hook(completed=0, total=num_chunks + has_last_chunk) + def __append_batch(output, batch_output, **kwargs) -> None: + output.append(batch_output) + return + # slide over audio chunks in batch for c in np.arange(0, num_chunks, self.batch_size): batch: torch.Tensor = chunks[c : c + self.batch_size] - outputs.append(self.infer(batch)) + + batch_outputs: Union[np.ndarray, Tuple[np.ndarray]] = self.infer(batch) + + _ = map_with_specifications( + self.model.specifications, __append_batch, outputs, batch_outputs + ) + if hook is not None: hook(completed=c + self.batch_size, total=num_chunks + has_last_chunk) # process orphan last chunk if has_last_chunk: + last_outputs = self.infer(last_chunk[None]) - last_output = self.infer(last_chunk[None]) - - if specifications.resolution == Resolution.FRAME: - pad = num_frames_per_chunk - last_output.shape[1] - last_output = np.pad(last_output, ((0, 0), (0, pad), (0, 0))) + _ = map_with_specifications( + self.model.specifications, __append_batch, outputs, last_outputs + ) - outputs.append(last_output) if hook is not None: hook( completed=num_chunks + has_last_chunk, total=num_chunks + has_last_chunk, ) - outputs = np.vstack(outputs) - - # skip aggregation when requested, - # or when model outputs just one vector per chunk - # or when model is permutation-invariant (and not post-processed) - if ( - self.skip_aggregation - or specifications.resolution == Resolution.CHUNK - or ( - specifications.permutation_invariant - and self.pre_aggregation_hook is None - ) - ): - frames = SlidingWindow(start=0.0, duration=self.duration, step=self.step) - return SlidingWindowFeature(outputs, frames) - - if self.pre_aggregation_hook is not None: - outputs = self.pre_aggregation_hook(outputs) - - aggregated = self.aggregate( - SlidingWindowFeature( - outputs, - SlidingWindow(start=0.0, duration=self.duration, step=self.step), - ), - frames=frames, - warm_up=self.warm_up, - hamming=True, - missing=0.0, + def __vstack(output: List[np.ndarray], **kwargs) -> np.ndarray: + return np.vstack(output) + + outputs: Union[np.ndarray, Tuple[np.ndarray]] = map_with_specifications( + self.model.specifications, __vstack, outputs ) - if has_last_chunk: - num_frames = aggregated.data.shape[0] - aggregated.data = aggregated.data[: num_frames - pad, :] + def __aggregate( + outputs: np.ndarray, + frames: SlidingWindow, + specifications: Optional[Specifications] = None, + ) -> SlidingWindowFeature: + # skip aggregation when requested, + # or when model outputs just one vector per chunk + # or when model is permutation-invariant (and not post-processed) + if ( + self.skip_aggregation + or specifications.resolution == Resolution.CHUNK + or ( + specifications.permutation_invariant + and self.pre_aggregation_hook is None + ) + ): + frames = SlidingWindow( + start=0.0, duration=self.duration, step=self.step + ) + return SlidingWindowFeature(outputs, frames) + + if self.pre_aggregation_hook is not None: + outputs = self.pre_aggregation_hook(outputs) + + aggregated = self.aggregate( + SlidingWindowFeature( + outputs, + SlidingWindow(start=0.0, duration=self.duration, step=self.step), + ), + frames=frames, + warm_up=self.warm_up, + hamming=True, + missing=0.0, + ) + + # remove padding that was added to last chunk + if has_last_chunk: + aggregated.data = aggregated.crop( + Segment(0.0, num_samples / sample_rate), mode="loose" + ) - return aggregated + return aggregated + + return map_with_specifications( + self.model.specifications, __aggregate, outputs, frames + ) def __call__( self, file: AudioFile, hook: Optional[Callable] = None - ) -> Union[SlidingWindowFeature, np.ndarray]: + ) -> Union[ + Tuple[Union[SlidingWindowFeature, np.ndarray]], + Union[SlidingWindowFeature, np.ndarray], + ]: """Run inference on a whole file Parameters @@ -357,7 +413,7 @@ def __call__( Returns ------- - output : SlidingWindowFeature or np.ndarray + output : (tuple of) SlidingWindowFeature or np.ndarray Model output, as `SlidingWindowFeature` if `window` is set to "sliding" and `np.ndarray` if is set to "whole". @@ -370,7 +426,14 @@ def __call__( if self.window == "sliding": return self.slide(waveform, sample_rate, hook=hook) - return self.infer(waveform[None])[0] + outputs: Union[np.ndarray, Tuple[np.ndarray]] = self.infer(waveform[None]) + + def __first_sample(outputs: np.ndarray, **kwargs) -> np.ndarray: + return outputs[0] + + return map_with_specifications( + self.model.specifications, __first_sample, outputs + ) def crop( self, @@ -378,7 +441,10 @@ def crop( chunk: Union[Segment, List[Segment]], duration: Optional[float] = None, hook: Optional[Callable] = None, - ) -> Union[SlidingWindowFeature, np.ndarray]: + ) -> Union[ + Tuple[Union[SlidingWindowFeature, np.ndarray]], + Union[SlidingWindowFeature, np.ndarray], + ]: """Run inference on a chunk or a list of chunks Parameters @@ -403,7 +469,7 @@ def crop( Returns ------- - output : SlidingWindowFeature or np.ndarray + output : (tuple of) SlidingWindowFeature or np.ndarray Model output, as `SlidingWindowFeature` if `window` is set to "sliding" and `np.ndarray` if is set to "whole". @@ -420,7 +486,6 @@ def crop( fix_reproducibility(self.device) if self.window == "sliding": - if not isinstance(chunk, Segment): start = min(c.start for c in chunk) end = max(c.end for c in chunk) @@ -429,32 +494,37 @@ def crop( waveform, sample_rate = self.model.audio.crop( file, chunk, duration=duration ) - output = self.slide(waveform, sample_rate, hook=hook) - - frames = output.sliding_window - shifted_frames = SlidingWindow( - start=chunk.start, duration=frames.duration, step=frames.step - ) - return SlidingWindowFeature(output.data, shifted_frames) - - elif self.window == "whole": - - if isinstance(chunk, Segment): - waveform, sample_rate = self.model.audio.crop( - file, chunk, duration=duration - ) - else: - waveform = torch.cat( - [self.model.audio.crop(file, c)[0] for c in chunk], dim=1 + outputs: Union[ + SlidingWindowFeature, Tuple[SlidingWindowFeature] + ] = self.slide(waveform, sample_rate, hook=hook) + + def __shift(output: SlidingWindowFeature, **kwargs) -> SlidingWindowFeature: + frames = output.sliding_window + shifted_frames = SlidingWindow( + start=chunk.start, duration=frames.duration, step=frames.step ) + return SlidingWindowFeature(output.data, shifted_frames) - return self.infer(waveform[None])[0] + return map_with_specifications(self.model.specifications, __shift, outputs) + if isinstance(chunk, Segment): + waveform, sample_rate = self.model.audio.crop( + file, chunk, duration=duration + ) else: - raise NotImplementedError( - f"Unsupported window type '{self.window}': should be 'sliding' or 'whole'." + waveform = torch.cat( + [self.model.audio.crop(file, c)[0] for c in chunk], dim=1 ) + outputs: Union[np.ndarray, Tuple[np.ndarray]] = self.infer(waveform[None]) + + def __first_sample(outputs: np.ndarray, **kwargs) -> np.ndarray: + return outputs[0] + + return map_with_specifications( + self.model.specifications, __first_sample, outputs + ) + @staticmethod def aggregate( scores: SlidingWindowFeature, @@ -696,7 +766,6 @@ def always_match(this: np.ndarray, that: np.ndarray, cost: float): stitches = [] for C, (chunk, activation) in enumerate(activations): - local_stitch = np.NAN * np.zeros( (sum(lookahead) + 1, num_frames, num_classes) ) @@ -704,7 +773,6 @@ def always_match(this: np.ndarray, that: np.ndarray, cost: float): for c in range( max(0, C - lookahead[0]), min(num_chunks, C + lookahead[1] + 1) ): - # extract common temporal support shift = round((C - c) * num_frames * chunks.step / chunks.duration) @@ -725,7 +793,6 @@ def always_match(this: np.ndarray, that: np.ndarray, cost: float): ) for this, that in enumerate(permutation): - # only stitch under certain condiditions matching = (c == C) or ( match_func( diff --git a/pyannote/audio/core/io.py b/pyannote/audio/core/io.py index b2e8842b1..0a44e75ea 100644 --- a/pyannote/audio/core/io.py +++ b/pyannote/audio/core/io.py @@ -150,7 +150,6 @@ def validate_file(file: AudioFile) -> Mapping: raise ValueError(AudioFileDocString) if "waveform" in file: - waveform: Union[np.ndarray, Tensor] = file["waveform"] if len(waveform.shape) != 2 or waveform.shape[0] > waveform.shape[1]: raise ValueError( @@ -166,7 +165,6 @@ def validate_file(file: AudioFile) -> Mapping: file.setdefault("uri", "waveform") elif "audio" in file: - if isinstance(file["audio"], IOBase): return file @@ -177,7 +175,6 @@ def validate_file(file: AudioFile) -> Mapping: file.setdefault("uri", path.stem) else: - raise ValueError( "Neither 'waveform' nor 'audio' is available for this file." ) @@ -185,7 +182,6 @@ def validate_file(file: AudioFile) -> Mapping: return file def __init__(self, sample_rate=None, mono=None): - super().__init__() self.sample_rate = sample_rate self.mono = mono @@ -257,6 +253,18 @@ def get_duration(self, file: AudioFile) -> float: return frames / sample_rate + def get_num_samples(self, duration: float, sample_rate: int = None) -> int: + """Deterministic number of samples from duration and sample rate""" + + sample_rate = sample_rate or self.sample_rate + + if sample_rate is None: + raise ValueError( + "`sample_rate` must be provided to compute number of samples." + ) + + return math.floor(duration * sample_rate) + def __call__(self, file: AudioFile) -> Tuple[Tensor, int]: """Obtain waveform @@ -359,7 +367,6 @@ def crop( num_frames = end_frame - start_frame if mode == "raise": - if num_frames > frames: raise ValueError( f"requested fixed duration ({duration:6f}s, or {num_frames:d} frames) is longer " @@ -400,7 +407,6 @@ def crop( if isinstance(file["audio"], IOBase): file["audio"].seek(0) except RuntimeError: - if isinstance(file["audio"], IOBase): msg = "torchaudio failed to seek-and-read in file-like object." raise RuntimeError(msg) diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index 18b301086..5cb6c0e6b 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -24,6 +24,8 @@ import os import warnings +from dataclasses import dataclass +from functools import cached_property from importlib import import_module from pathlib import Path from typing import Any, Dict, List, Optional, Text, Tuple, Union @@ -49,6 +51,7 @@ Task, UnknownSpecificationsError, ) +from pyannote.audio.utils.multi_task import map_with_specifications from pyannote.audio.utils.version import check_version CACHE_DIR = os.getenv( @@ -59,195 +62,16 @@ HF_LIGHTNING_CONFIG_NAME = "config.yaml" +# NOTE: needed to backward compatibility to load models trained before pyannote.audio 3.x class Introspection: - """Model introspection + pass - Parameters - ---------- - min_num_samples: int - Minimum number of input samples - min_num_frames: int - Corresponding minimum number of output frames - inc_num_samples: int - Number of input samples leading to an increase of number of output frames - inc_num_frames: int - Corresponding increase in number of output frames - dimension: int - Output dimension - sample_rate: int - Expected input sample rate - - Usage - ----- - >>> introspection = Introspection.from_model(model) - >>> isinstance(introspection.frames, SlidingWindow) - >>> num_samples = 16000 # 1s at 16kHz - >>> num_frames, dimension = introspection(num_samples) - """ - - def __init__( - self, - min_num_samples: int, - min_num_frames: int, - inc_num_samples: int, - inc_num_frames: int, - dimension: int, - sample_rate: int, - ): - super().__init__() - self.min_num_samples = min_num_samples - self.min_num_frames = min_num_frames - self.inc_num_samples = inc_num_samples - self.inc_num_frames = inc_num_frames - self.dimension = dimension - self.sample_rate = sample_rate - - @classmethod - def from_model(cls, model: "Model", task: str = None) -> Introspection: - - specifications = model.specifications - if task is not None: - specifications = specifications[task] - - example_input_array = model.example_input_array - batch_size, num_channels, num_samples = example_input_array.shape - example_input_array = torch.randn( - (batch_size, num_channels, num_samples), - dtype=example_input_array.dtype, - layout=example_input_array.layout, - device=example_input_array.device, - requires_grad=False, - ) - - # dichotomic search of "min_num_samples" - lower, upper, min_num_samples = 1, num_samples, None - while True: - num_samples = (lower + upper) // 2 - try: - with torch.no_grad(): - frames = model(example_input_array[:, :, :num_samples]) - if task is not None: - frames = frames[task] - except Exception: - lower = num_samples - else: - min_num_samples = num_samples - if specifications.resolution == Resolution.FRAME: - _, min_num_frames, dimension = frames.shape - elif specifications.resolution == Resolution.CHUNK: - _, dimension = frames.shape - else: - # should never happen - pass - upper = num_samples - - if lower + 1 == upper: - break - - # if "min_num_samples" is still None at this point, it means that - # the forward pass always failed and raised an exception. most likely, - # it means that there is a problem with the model definition. - # we try again without catching the exception to help the end user debug - # their model - if min_num_samples is None: - frames = model(example_input_array) - - # corner case for chunk-level tasks - if specifications.resolution == Resolution.CHUNK: - return cls( - min_num_samples=min_num_samples, - min_num_frames=1, - inc_num_samples=0, - inc_num_frames=0, - dimension=dimension, - sample_rate=model.hparams.sample_rate, - ) - - # search reasonable upper bound for "inc_num_samples" - while True: - num_samples = 2 * min_num_samples - example_input_array = torch.randn( - (batch_size, num_channels, num_samples), - dtype=example_input_array.dtype, - layout=example_input_array.layout, - device=example_input_array.device, - requires_grad=False, - ) - with torch.no_grad(): - frames = model(example_input_array) - if task is not None: - frames = frames[task] - num_frames = frames.shape[1] - if num_frames > min_num_frames: - break - - # dichotomic search of "inc_num_samples" - lower, upper = min_num_samples, num_samples - while True: - num_samples = (lower + upper) // 2 - example_input_array = torch.randn( - (batch_size, num_channels, num_samples), - dtype=example_input_array.dtype, - layout=example_input_array.layout, - device=example_input_array.device, - requires_grad=False, - ) - with torch.no_grad(): - frames = model(example_input_array) - if task is not None: - frames = frames[task] - num_frames = frames.shape[1] - if num_frames > min_num_frames: - inc_num_frames = num_frames - min_num_frames - inc_num_samples = num_samples - min_num_samples - upper = num_samples - else: - lower = num_samples - - if lower + 1 == upper: - break - return cls( - min_num_samples=min_num_samples, - min_num_frames=min_num_frames, - inc_num_samples=inc_num_samples, - inc_num_frames=inc_num_frames, - dimension=dimension, - sample_rate=model.hparams.sample_rate, - ) - - def __call__(self, num_samples: int) -> Tuple[int, int]: - """Predict output shape, given number of input samples - - Parameters - ---------- - num_samples : int - Number of input samples. - - Returns - ------- - num_frames : int - Number of output frames - dimension : int - Dimension of output frames - """ - - if num_samples < self.min_num_samples: - return 0, self.dimension - - return ( - self.min_num_frames - + self.inc_num_frames - * ((num_samples - self.min_num_samples + 1) // self.inc_num_samples), - self.dimension, - ) - - @property - def frames(self) -> SlidingWindow: - # HACK to support model trained before 'sample_rate' was an Introspection attribute - sample_rate = getattr(self, "sample_rate", 16000) - step = (self.inc_num_samples / self.inc_num_frames) / sample_rate - return SlidingWindow(start=0.0, step=step, duration=step) +@dataclass +class Output: + num_frames: int + dimension: int + frames: SlidingWindow class Model(pl.LightningModule): @@ -281,31 +105,26 @@ def __init__( self.audio = Audio(sample_rate=self.hparams.sample_rate, mono="downmix") @property - def example_input_array(self) -> torch.Tensor: - batch_size = 3 if self.task is None else self.task.batch_size - duration = 2.0 if self.task is None else self.task.duration - - return torch.randn( - ( - batch_size, - self.hparams.num_channels, - int(self.hparams.sample_rate * duration), - ), - device=self.device, - ) - - @property - def task(self): + def task(self) -> Task: return self._task @task.setter - def task(self, task): - self._task = task - del self.introspection + def task(self, task: Task): + # reset (cached) properties when task changes del self.specifications + try: + del self.example_output + except AttributeError: + pass + self._task = task + + def build(self): + # use this method to add task-dependent layers to the model + # (e.g. the final classification and activation layers) + pass @property - def specifications(self): + def specifications(self) -> Union[Specifications, Tuple[Specifications]]: if self.task is None: try: specifications = self._specifications @@ -330,7 +149,22 @@ def specifications(self): return specifications @specifications.setter - def specifications(self, specifications): + def specifications( + self, specifications: Union[Specifications, Tuple[Specifications]] + ): + if not isinstance(specifications, (Specifications, tuple)): + raise ValueError( + "Only regular specifications or tuple of specifications are supported." + ) + + durations = set(s.duration for s in specifications) + if len(durations) > 1: + raise ValueError("All tasks must share the same (maximum) duration.") + + min_durations = set(s.min_duration for s in specifications) + if len(min_durations) > 1: + raise ValueError("All tasks must share the same minimum duration.") + self._specifications = specifications @specifications.deleter @@ -338,39 +172,53 @@ def specifications(self): if hasattr(self, "_specifications"): del self._specifications - def build(self): - # use this method to add task-dependent layers to the model - # (e.g. the final classification and activation layers) - pass + def __example_input_array(self, duration: Optional[float] = None) -> torch.Tensor: + duration = duration or next(iter(self.specifications)).duration + return torch.randn( + ( + 1, + self.hparams.num_channels, + self.audio.get_num_samples(duration), + ), + device=self.device, + ) @property - def introspection(self) -> Introspection: - """Introspection - - Returns - ------- - introspection: Introspection - Model introspection - """ - - if not hasattr(self, "_introspection"): - self._introspection = Introspection.from_model(self) - - return self._introspection + def example_input_array(self) -> torch.Tensor: + return self.__example_input_array() + + @cached_property + def example_output(self) -> Union[Output, Tuple[Output]]: + """Example output""" + example_input_array = self.__example_input_array() + with torch.inference_mode(): + example_output = self(example_input_array) + + def __example_output( + example_output: torch.Tensor, + specifications: Specifications = None, + ) -> Output: + _, num_frames, dimension = example_output.shape + + if specifications.resolution == Resolution.FRAME: + frame_duration = specifications.duration / num_frames + frames = SlidingWindow(step=frame_duration, duration=frame_duration) + else: + frames = None - @introspection.setter - def introspection(self, introspection): - self._introspection = introspection + return Output( + num_frames=num_frames, + dimension=dimension, + frames=frames, + ) - @introspection.deleter - def introspection(self): - if hasattr(self, "_introspection"): - del self._introspection + return map_with_specifications( + self.specifications, __example_output, example_output + ) def setup(self, stage=None): - if stage == "fit": - self.task.setup() + self.task.setup_metadata() # list of layers before adding task-dependent layers before = set((name, id(module)) for name, module in self.named_modules()) @@ -411,8 +259,8 @@ def setup(self, stage=None): # setup custom validation metrics self.task.setup_validation_metric() - # this is to make sure introspection is performed here, once and for all - _ = self.introspection + # cache for later (and to avoid later CUDA error with multiprocessing) + _ = self.example_output # list of layers after adding task-dependent layers after = set((name, id(module)) for name, module in self.named_modules()) @@ -421,7 +269,6 @@ def setup(self, stage=None): self.task_dependent = list(name for name, _ in after - before) def on_save_checkpoint(self, checkpoint): - # put everything pyannote.audio-specific under pyannote.audio # to avoid any future conflicts with pytorch-lightning updates checkpoint["pyannote.audio"] = { @@ -433,12 +280,10 @@ def on_save_checkpoint(self, checkpoint): "module": self.__class__.__module__, "class": self.__class__.__name__, }, - "introspection": self.introspection, "specifications": self.specifications, } def on_load_checkpoint(self, checkpoint: Dict[str, Any]): - check_version( "pyannote.audio", checkpoint["pyannote.audio"]["versions"]["pyannote.audio"], @@ -462,43 +307,17 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]): self.specifications = checkpoint["pyannote.audio"]["specifications"] + # add task-dependent (e.g. final classifier) layers self.setup() - self.introspection = checkpoint["pyannote.audio"]["introspection"] - - def forward(self, waveforms: torch.Tensor) -> torch.Tensor: + def forward( + self, waveforms: torch.Tensor, **kwargs + ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: msg = "Class {self.__class__.__name__} should define a `forward` method." raise NotImplementedError(msg) - def helper_default_activation(self, specifications: Specifications) -> nn.Module: - """Helper function for default_activation - - Parameters - ---------- - specifications: Specifications - Task specification. - - Returns - ------- - activation : nn.Module - Default activation function. - """ - - if specifications.problem == Problem.BINARY_CLASSIFICATION: - return nn.Sigmoid() - - elif specifications.problem == Problem.MONO_LABEL_CLASSIFICATION: - return nn.LogSoftmax(dim=-1) - - elif specifications.problem == Problem.MULTI_LABEL_CLASSIFICATION: - return nn.Sigmoid() - - else: - msg = "TODO: implement default activation for other types of problems" - raise NotImplementedError(msg) - # convenience function to automate the choice of the final activation function - def default_activation(self) -> nn.Module: + def default_activation(self) -> Union[nn.Module, Tuple[nn.Module]]: """Guess default activation function according to task specification * sigmoid for binary classification @@ -507,10 +326,25 @@ def default_activation(self) -> nn.Module: Returns ------- - activation : nn.Module + activation : (tuple of) nn.Module Activation. """ - return self.helper_default_activation(self.specifications) + + def __default_activation(specifications: Specifications = None) -> nn.Module: + if specifications.problem == Problem.BINARY_CLASSIFICATION: + return nn.Sigmoid() + + elif specifications.problem == Problem.MONO_LABEL_CLASSIFICATION: + return nn.LogSoftmax(dim=-1) + + elif specifications.problem == Problem.MULTI_LABEL_CLASSIFICATION: + return nn.Sigmoid() + + else: + msg = "TODO: implement default activation for other types of problems" + raise NotImplementedError(msg) + + return map_with_specifications(self.specifications, __default_activation) # training data logic is delegated to the task because the # model does not really need to know how it is being used. @@ -535,9 +369,7 @@ def validation_step(self, batch, batch_idx): def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=1e-3) - def _helper_up_to( - self, module_name: Text, requires_grad: bool = False - ) -> List[Text]: + def __up_to(self, module_name: Text, requires_grad: bool = False) -> List[Text]: """Helper function for freeze_up_to and unfreeze_up_to""" tokens = module_name.split(".") @@ -594,7 +426,7 @@ def freeze_up_to(self, module_name: Text) -> List[Text]: If your model does not follow a sequential structure, you might want to use freeze_by_name for more control. """ - return self._helper_up_to(module_name, requires_grad=False) + return self.__up_to(module_name, requires_grad=False) def unfreeze_up_to(self, module_name: Text) -> List[Text]: """Unfreeze model up to specific module @@ -619,9 +451,9 @@ def unfreeze_up_to(self, module_name: Text) -> List[Text]: If your model does not follow a sequential structure, you might want to use freeze_by_name for more control. """ - return self._helper_up_to(module_name, requires_grad=True) + return self.__up_to(module_name, requires_grad=True) - def _helper_by_name( + def __by_name( self, modules: Union[List[Text], Text], recurse: bool = True, @@ -636,7 +468,6 @@ def _helper_by_name( modules = [modules] for name, module in ModelSummary(self, max_depth=-1).named_modules: - if name not in modules: continue @@ -678,7 +509,7 @@ def freeze_by_name( ValueError if at least one of `modules` does not exist. """ - return self._helper_by_name( + return self.__by_name( modules, recurse=recurse, requires_grad=False, @@ -709,7 +540,7 @@ def unfreeze_by_name( ValueError if at least one of `modules` does not exist. """ - return self._helper_by_name(modules, recurse=recurse, requires_grad=True) + return self.__by_name(modules, recurse=recurse, requires_grad=True) @classmethod def from_pretrained( @@ -826,7 +657,6 @@ def from_pretrained( # HACK do not use it. Fails silently in case model does not # HACK have a config.yaml file. try: - _ = hf_hub_download( model_id, HF_LIGHTNING_CONFIG_NAME, diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 9bf93bf1c..1edfbc35c 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -72,9 +72,11 @@ class Specifications: problem: Problem resolution: Resolution - # chunk duration in seconds. - # use None for variable-length chunks - duration: Optional[float] = None + # (maximum) chunk duration in seconds + duration: float + + # (for variable-duration tasks only) minimum chunk duration in seconds + min_duration: Optional[float] = None # use that many seconds on the left- and rightmost parts of each chunk # to warm up the model. This is mostly useful for segmentation tasks. @@ -95,7 +97,7 @@ class Specifications: permutation_invariant: bool = False @cached_property - def powerset(self): + def powerset(self) -> bool: if self.powerset_max_classes is None: return False @@ -118,6 +120,12 @@ def num_powerset_classes(self) -> int: ) ) + def __len__(self): + return 1 + + def __iter__(self): + yield self + class TrainDataset(IterableDataset): def __init__(self, task: Task): @@ -191,7 +199,7 @@ class Task(pl.LightningDataModule): Attributes ---------- - specifications : Specifications or dict of Specifications + specifications : Specifications or tuple of Specifications Task specifications (available after `Task.setup` has been called.) """ @@ -260,7 +268,28 @@ def prepare_data(self): """ pass - def setup(self, stage: Optional[str] = None): + @property + def specifications(self) -> Union[Specifications, Tuple[Specifications]]: + # setup metadata on-demand the first time specifications are requested and missing + if not hasattr(self, "_specifications"): + self.setup_metadata() + return self._specifications + + @specifications.setter + def specifications( + self, specifications: Union[Specifications, Tuple[Specifications]] + ): + self._specifications = specifications + + @property + def has_setup_metadata(self): + return getattr(self, "_has_setup_metadata", False) + + @has_setup_metadata.setter + def has_setup_metadata(self, value: bool): + self._has_setup_metadata = value + + def setup_metadata(self): """Called at the beginning of training at the very beginning of Model.setup(stage="fit") Notes @@ -270,7 +299,10 @@ def setup(self, stage: Optional[str] = None): If `specifications` attribute has not been set in `__init__`, `setup` is your last chance to set it. """ - pass + + if not self.has_setup_metadata: + self.setup() + self.has_setup_metadata = True def setup_loss_func(self): pass @@ -362,6 +394,11 @@ def common_step(self, batch, batch_idx: int, stage: Literal["train", "val"]): {"loss": loss} """ + if isinstance(self.specifications, tuple): + raise NotImplementedError( + "Default training/validation step is not implemented for multi-task." + ) + # forward pass y_pred = self.model(batch["X"]) diff --git a/pyannote/audio/models/segmentation/PyanNet.py b/pyannote/audio/models/segmentation/PyanNet.py index 1b68a32a9..5af3734b1 100644 --- a/pyannote/audio/models/segmentation/PyanNet.py +++ b/pyannote/audio/models/segmentation/PyanNet.py @@ -80,7 +80,6 @@ def __init__( num_channels: int = 1, task: Optional[Task] = None, ): - super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) sincnet = merge_dict(self.SINCNET_DEFAULTS, sincnet) @@ -140,7 +139,6 @@ def __init__( ) def build(self): - if self.hparams.linear["num_layers"] > 0: in_features = self.hparams.linear["hidden_size"] else: @@ -148,6 +146,9 @@ def build(self): 2 if self.hparams.lstm["bidirectional"] else 1 ) + if isinstance(self.specifications, tuple): + raise ValueError("PyanNet does not support multi-tasking.") + if self.specifications.powerset: out_features = self.specifications.num_powerset_classes else: diff --git a/pyannote/audio/models/segmentation/debug.py b/pyannote/audio/models/segmentation/debug.py index 498faee27..89512320c 100644 --- a/pyannote/audio/models/segmentation/debug.py +++ b/pyannote/audio/models/segmentation/debug.py @@ -39,7 +39,6 @@ def __init__( num_channels: int = 1, task: Optional[Task] = None, ): - super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) self.mfcc = MFCC( @@ -60,7 +59,16 @@ def __init__( def build(self): # define task-dependent layers - self.classifier = nn.Linear(32 * 2, len(self.specifications.classes)) + + if isinstance(self.specifications, tuple): + raise ValueError("SimpleSegmentationModel does not support multi-tasking.") + + if self.specifications.powerset: + out_features = self.specifications.num_powerset_classes + else: + out_features = len(self.specifications.classes) + + self.classifier = nn.Linear(32 * 2, out_features) self.activation = self.default_activation() def forward(self, waveforms: torch.Tensor) -> torch.Tensor: diff --git a/pyannote/audio/pipelines/overlapped_speech_detection.py b/pyannote/audio/pipelines/overlapped_speech_detection.py index 9b14ee10f..064cae1be 100644 --- a/pyannote/audio/pipelines/overlapped_speech_detection.py +++ b/pyannote/audio/pipelines/overlapped_speech_detection.py @@ -128,7 +128,7 @@ def __init__( # load model model = get_model(segmentation, use_auth_token=use_auth_token) - if model.introspection.dimension > 1: + if model.example_output.dimension > 1: inference_kwargs["pre_aggregation_hook"] = lambda scores: np.partition( scores, -2, axis=-1 )[:, :, -2, np.newaxis] diff --git a/pyannote/audio/pipelines/resegmentation.py b/pyannote/audio/pipelines/resegmentation.py index 57cf9004b..bb71abf22 100644 --- a/pyannote/audio/pipelines/resegmentation.py +++ b/pyannote/audio/pipelines/resegmentation.py @@ -88,7 +88,6 @@ def __init__( der_variant: dict = None, use_auth_token: Union[Text, None] = None, ): - super().__init__() self.segmentation = segmentation @@ -96,7 +95,7 @@ def __init__( model: Model = get_model(segmentation, use_auth_token=use_auth_token) self._segmentation = Inference(model) - self._frames = self._segmentation.model.introspection.frames + self._frames = self._segmentation.model.example_output.frames self._audio = model.audio diff --git a/pyannote/audio/pipelines/speaker_diarization.py b/pyannote/audio/pipelines/speaker_diarization.py index 6bc81f28a..8cf30f3b9 100644 --- a/pyannote/audio/pipelines/speaker_diarization.py +++ b/pyannote/audio/pipelines/speaker_diarization.py @@ -136,7 +136,7 @@ def __init__( skip_aggregation=True, batch_size=segmentation_batch_size, ) - self._frames: SlidingWindow = self._segmentation.model.introspection.frames + self._frames: SlidingWindow = self._segmentation.model.example_output.frames if self._segmentation.model.specifications.powerset: self.segmentation = ParamDict( diff --git a/pyannote/audio/pipelines/speaker_verification.py b/pyannote/audio/pipelines/speaker_verification.py index d99537cbc..005c8964f 100644 --- a/pyannote/audio/pipelines/speaker_verification.py +++ b/pyannote/audio/pipelines/speaker_verification.py @@ -64,7 +64,6 @@ def __init__( embedding: Text = "nvidia/speakerverification_en_titanet_large", device: torch.device = None, ): - if not NEMO_IS_AVAILABLE: raise ImportError( f"'NeMo' must be installed to use '{embedding}' embeddings. " @@ -95,7 +94,6 @@ def sample_rate(self) -> int: @cached_property def dimension(self) -> int: - input_signal = torch.rand(1, self.sample_rate).to(self.device) input_signal_length = torch.tensor([self.sample_rate]).to(self.device) _, embeddings = self.model_( @@ -110,7 +108,6 @@ def metric(self) -> str: @cached_property def min_num_samples(self) -> int: - lower, upper = 2, round(0.5 * self.sample_rate) middle = (lower + upper) // 2 while lower + 1 < upper: @@ -157,7 +154,6 @@ def __call__( wav_lens = signals.shape[1] * torch.ones(batch_size) else: - batch_size_masks, _ = masks.shape assert batch_size == batch_size_masks @@ -234,7 +230,6 @@ def __init__( device: torch.device = None, use_auth_token: Union[Text, None] = None, ): - if not SPEECHBRAIN_IS_AVAILABLE: raise ImportError( f"'speechbrain' must be installed to use '{embedding}' embeddings. " @@ -291,19 +286,19 @@ def metric(self) -> str: @cached_property def min_num_samples(self) -> int: - - lower, upper = 2, round(0.5 * self.sample_rate) - middle = (lower + upper) // 2 - while lower + 1 < upper: - try: - _ = self.classifier_.encode_batch( - torch.randn(1, middle).to(self.device) - ) - upper = middle - except RuntimeError: - lower = middle - + with torch.inference_mode(): + lower, upper = 2, round(0.5 * self.sample_rate) middle = (lower + upper) // 2 + while lower + 1 < upper: + try: + _ = self.classifier_.encode_batch( + torch.randn(1, middle).to(self.device) + ) + upper = middle + except RuntimeError: + lower = middle + + middle = (lower + upper) // 2 return upper @@ -334,7 +329,6 @@ def __call__( wav_lens = signals.shape[1] * torch.ones(batch_size) else: - batch_size_masks, _ = masks.shape assert batch_size == batch_size_masks @@ -440,7 +434,7 @@ def sample_rate(self) -> int: @cached_property def dimension(self) -> int: - return self.model_.introspection.dimension + return self.model_.example_output.dimension @cached_property def metric(self) -> str: @@ -448,12 +442,24 @@ def metric(self) -> str: @cached_property def min_num_samples(self) -> int: - return self.model_.introspection.min_num_samples + with torch.inference_mode(): + lower, upper = 2, round(0.5 * self.sample_rate) + middle = (lower + upper) // 2 + while lower + 1 < upper: + try: + _ = self.model_(torch.randn(1, 1, middle).to(self.device)) + upper = middle + except RuntimeError: + lower = middle + + middle = (lower + upper) // 2 + + return upper def __call__( self, waveforms: torch.Tensor, masks: torch.Tensor = None ) -> np.ndarray: - with torch.no_grad(): + with torch.inference_mode(): if masks is None: embeddings = self.model_(waveforms.to(self.device)) else: @@ -572,7 +578,6 @@ def __init__( ) def apply(self, file: AudioFile) -> np.ndarray: - device = self.embedding_model_.device # read audio file and send it to GPU @@ -598,7 +603,6 @@ def main( embedding: str = "pyannote/embedding", segmentation: str = None, ): - import typer from pyannote.database import FileFinder, get_protocol from pyannote.metrics.binary_classification import det_curve @@ -616,7 +620,6 @@ def main( trials = getattr(protocol, f"{subset}_trial")() for t, trial in enumerate(tqdm(trials)): - audio1 = trial["file1"]["audio"] if audio1 not in emb: emb[audio1] = pipeline(audio1) diff --git a/pyannote/audio/pipelines/utils/oracle.py b/pyannote/audio/pipelines/utils/oracle.py index 486b09274..44b4ded61 100644 --- a/pyannote/audio/pipelines/utils/oracle.py +++ b/pyannote/audio/pipelines/utils/oracle.py @@ -39,7 +39,7 @@ def oracle_segmentation( Simulates inference based on an (imaginary) oracle segmentation model: >>> oracle = Model.from_pretrained("oracle") - >>> assert frames == oracle.introspection.frames + >>> assert frames == oracle.example_output.frames >>> inference = Inference(oracle, duration=window.duration, step=window.step, skip_aggregation=True) >>> oracle_segmentation = inference(file) diff --git a/pyannote/audio/tasks/embedding/mixins.py b/pyannote/audio/tasks/embedding/mixins.py index f5e41d3ee..da164f04e 100644 --- a/pyannote/audio/tasks/embedding/mixins.py +++ b/pyannote/audio/tasks/embedding/mixins.py @@ -21,7 +21,7 @@ # SOFTWARE. import math -from typing import Dict, Optional, Sequence, Union +from typing import Dict, Sequence, Union import torch import torch.nn.functional as F @@ -75,13 +75,10 @@ def batch_size(self) -> int: def batch_size(self, batch_size: int): self.batch_size_ = batch_size - def setup(self, stage: Optional[str] = None): + def setup(self): # loop over the training set, remove annotated regions shorter than # chunk duration, and keep track of the reference annotations, per class. - # FIXME: it looks like this time consuming step is called multiple times. - # it should not be... - self._train = dict() desc = f"Loading {self.protocol.name} training labels" @@ -118,6 +115,7 @@ def setup(self, stage: Optional[str] = None): problem=Problem.REPRESENTATION, resolution=Resolution.CHUNK, duration=self.duration, + min_duration=self.min_duration, classes=sorted(self._train), ) @@ -151,6 +149,7 @@ def train__iter__(self): classes = list(self.specifications.classes) + # select batch-wise duration at random batch_duration = rng.uniform(self.min_duration, self.duration) num_samples = 0 diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index 3db93824d..142245ae8 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -25,7 +25,7 @@ import random import warnings from collections import defaultdict -from typing import Dict, Optional, Sequence, Union +from typing import Dict, Sequence, Union import matplotlib.pyplot as plt import numpy as np @@ -71,14 +71,8 @@ def get_file(self, file_id): return file - def setup(self, stage: Optional[str] = None): - """Setup method - - Parameters - ---------- - stage : {'fit', 'validate', 'test'}, optional - Setup stage. Defaults to 'fit'. - """ + def setup(self): + """Setup""" # duration of training chunks # TODO: handle variable duration case diff --git a/pyannote/audio/tasks/segmentation/multilabel.py b/pyannote/audio/tasks/segmentation/multilabel.py index da6104386..5588ccdff 100644 --- a/pyannote/audio/tasks/segmentation/multilabel.py +++ b/pyannote/audio/tasks/segmentation/multilabel.py @@ -25,7 +25,7 @@ import numpy as np import torch import torch.nn.functional as F -from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature +from pyannote.core import Segment, SlidingWindowFeature from pyannote.database import Protocol from pyannote.database.protocol import SegmentationProtocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform @@ -119,14 +119,15 @@ def __init__( # classes should be detected. therefore, we postpone the definition of # specifications to setup() - def setup(self, stage: Optional[str] = None): - super().setup(stage=stage) + def setup(self): + super().setup() self.specifications = Specifications( classes=self.classes, problem=Problem.MULTI_LABEL_CLASSIFICATION, resolution=Resolution.FRAME, duration=self.duration, + min_duration=self.min_duration, warm_up=self.warm_up, ) @@ -167,14 +168,6 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample = dict() sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) - - # TODO: this should be cached - # use model introspection to predict how many frames it will output - num_samples = sample["X"].shape[1] - num_frames, _ = self.model.introspection(num_samples) - resolution = duration / num_frames - frames = SlidingWindow(start=0.0, duration=resolution, step=resolution) - # gather all annotations of current file annotations = self.annotations[self.annotations["file_id"] == file_id] @@ -185,19 +178,23 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # discretize chunk annotations at model output resolution start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - start_idx = np.floor(start / resolution).astype(int) + start_idx = np.floor(start / self.model.example_output.frames.step).astype(int) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - end_idx = np.ceil(end / resolution).astype(int) + end_idx = np.ceil(end / self.model.example_output.frames.step).astype(int) # frame-level targets (-1 for un-annotated classes) - y = -np.ones((num_frames, len(self.classes)), dtype=np.int8) + y = -np.ones( + (self.model.example_output.num_frames, len(self.classes)), dtype=np.int8 + ) y[:, self.annotated_classes[file_id]] = 0 for start, end, label in zip( start_idx, end_idx, chunk_annotations["global_label_idx"] ): y[start:end, label] = 1 - sample["y"] = SlidingWindowFeature(y, frames, labels=self.classes) + sample["y"] = SlidingWindowFeature( + y, self.model.example_output.frames, labels=self.classes + ) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} diff --git a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py index 658c350a7..cd3711d61 100644 --- a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py +++ b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py @@ -24,7 +24,7 @@ from typing import Dict, Sequence, Text, Tuple, Union import numpy as np -from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature +from pyannote.core import Segment, SlidingWindowFeature from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric @@ -106,7 +106,6 @@ def __init__( augmentation: BaseWaveformTransform = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): - super().__init__( protocol, duration=duration, @@ -122,6 +121,7 @@ def __init__( problem=Problem.BINARY_CLASSIFICATION, resolution=Resolution.FRAME, duration=self.duration, + min_duration=self.min_duration, warm_up=self.warm_up, classes=[ "overlap", @@ -162,13 +162,6 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample = dict() sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) - # use model introspection to predict how many frames it will output - # TODO: this should be cached - num_samples = sample["X"].shape[1] - num_frames, _ = self.model.introspection(num_samples) - resolution = duration / num_frames - frames = SlidingWindow(start=0.0, duration=resolution, step=resolution) - # gather all annotations of current file annotations = self.annotations[self.annotations["file_id"] == file_id] @@ -179,17 +172,19 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # discretize chunk annotations at model output resolution start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - start_idx = np.floor(start / resolution).astype(int) + start_idx = np.floor(start / self.model.example_output.frames.step).astype(int) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - end_idx = np.ceil(end / resolution).astype(int) + end_idx = np.ceil(end / self.model.example_output.frames.step).astype(int) # frame-level targets - y = np.zeros((num_frames, 1), dtype=np.uint8) + y = np.zeros((self.model.example_output.num_frames, 1), dtype=np.uint8) for start, end in zip(start_idx, end_idx): y[start:end, 0] += 1 y = 1 * (y > 1) - sample["y"] = SlidingWindowFeature(y, frames, labels=["speech"]) + sample["y"] = SlidingWindowFeature( + y, self.model.example_output.frames, labels=["speech"] + ) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index b6838a71c..eac795a47 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -23,13 +23,13 @@ import math import warnings from collections import Counter -from typing import Dict, Literal, Optional, Sequence, Text, Tuple, Union +from typing import Dict, Literal, Sequence, Text, Tuple, Union import numpy as np import torch import torch.nn.functional from matplotlib import pyplot as plt -from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature +from pyannote.core import Segment, SlidingWindowFeature from pyannote.database.protocol import SpeakerDiarizationProtocol from pyannote.database.protocol.protocol import Scope, Subset from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger @@ -186,8 +186,8 @@ def __init__( self.weight = weight self.vad_loss = vad_loss - def setup(self, stage: Optional[str] = None): - super().setup(stage=stage) + def setup(self): + super().setup() # estimate maximum number of speakers per chunk when not provided if self.max_speakers_per_chunk is None: @@ -276,6 +276,7 @@ def setup(self, stage: Optional[str] = None): else Problem.MONO_LABEL_CLASSIFICATION, resolution=Resolution.FRAME, duration=self.duration, + min_duration=self.min_duration, warm_up=self.warm_up, classes=[f"speaker#{i+1}" for i in range(self.max_speakers_per_chunk)], powerset_max_classes=self.max_speakers_per_frame, @@ -326,13 +327,6 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample = dict() sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) - # use model introspection to predict how many frames it will output - # TODO: this should be cached - num_samples = sample["X"].shape[1] - num_frames, _ = self.model.introspection(num_samples) - resolution = duration / num_frames - frames = SlidingWindow(start=0.0, duration=resolution, step=resolution) - # gather all annotations of current file annotations = self.annotations[self.annotations["file_id"] == file_id] @@ -343,9 +337,9 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # discretize chunk annotations at model output resolution start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - start_idx = np.floor(start / resolution).astype(int) + start_idx = np.floor(start / self.model.example_output.frames.step).astype(int) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - end_idx = np.ceil(end / resolution).astype(int) + end_idx = np.ceil(end / self.model.example_output.frames.step).astype(int) # get list and number of labels for current scope labels = list(np.unique(chunk_annotations[label_scope_key])) @@ -355,7 +349,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): pass # initial frame-level targets - y = np.zeros((num_frames, num_labels), dtype=np.uint8) + y = np.zeros((self.model.example_output.num_frames, num_labels), dtype=np.uint8) # map labels to indices mapping = {label: idx for idx, label in enumerate(labels)} @@ -366,7 +360,9 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): mapped_label = mapping[label] y[start:end, mapped_label] = 1 - sample["y"] = SlidingWindowFeature(y, frames, labels=labels) + sample["y"] = SlidingWindowFeature( + y, self.model.example_output.frames, labels=labels + ) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} @@ -553,11 +549,7 @@ def training_step(self, batch, batch_idx: int): weight[:, num_frames - warm_up_right :] = 0.0 if self.specifications.powerset: - powerset = torch.nn.functional.one_hot( - torch.argmax(prediction, dim=-1), - self.model.powerset.num_powerset_classes, - ).float() - multilabel = self.model.powerset.to_multilabel(powerset) + multilabel = self.model.powerset.to_multilabel(prediction) permutated_target, _ = permutate(multilabel, target) permutated_target_powerset = self.model.powerset.to_powerset( permutated_target.float() @@ -686,11 +678,7 @@ def validation_step(self, batch, batch_idx: int): weight[:, num_frames - warm_up_right :] = 0.0 if self.specifications.powerset: - powerset = torch.nn.functional.one_hot( - torch.argmax(prediction, dim=-1), - self.model.powerset.num_powerset_classes, - ).float() - multilabel = self.model.powerset.to_multilabel(powerset) + multilabel = self.model.powerset.to_multilabel(prediction) permutated_target, _ = permutate(multilabel, target) # FIXME: handle case where target have too many speakers? diff --git a/pyannote/audio/tasks/segmentation/voice_activity_detection.py b/pyannote/audio/tasks/segmentation/voice_activity_detection.py index 559ff24eb..967ea1f9b 100644 --- a/pyannote/audio/tasks/segmentation/voice_activity_detection.py +++ b/pyannote/audio/tasks/segmentation/voice_activity_detection.py @@ -23,7 +23,7 @@ from typing import Dict, Sequence, Text, Tuple, Union import numpy as np -from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature +from pyannote.core import Segment, SlidingWindowFeature from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric @@ -89,7 +89,6 @@ def __init__( augmentation: BaseWaveformTransform = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): - super().__init__( protocol, duration=duration, @@ -108,6 +107,7 @@ def __init__( problem=Problem.BINARY_CLASSIFICATION, resolution=Resolution.FRAME, duration=self.duration, + min_duration=self.min_duration, warm_up=self.warm_up, classes=[ "speech", @@ -144,13 +144,6 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample = dict() sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) - # use model introspection to predict how many frames it will output - # TODO: this should be cached - num_samples = sample["X"].shape[1] - num_frames, _ = self.model.introspection(num_samples) - resolution = duration / num_frames - frames = SlidingWindow(start=0.0, duration=resolution, step=resolution) - # gather all annotations of current file annotations = self.annotations[self.annotations["file_id"] == file_id] @@ -161,16 +154,18 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # discretize chunk annotations at model output resolution start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - start_idx = np.floor(start / resolution).astype(int) + start_idx = np.floor(start / self.model.example_output.frames.step).astype(int) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - end_idx = np.ceil(end / resolution).astype(int) + end_idx = np.ceil(end / self.model.example_output.frames.step).astype(int) # frame-level targets - y = np.zeros((num_frames, 1), dtype=np.uint8) + y = np.zeros((self.model.example_output.num_frames, 1), dtype=np.uint8) for start, end in zip(start_idx, end_idx): y[start:end, 0] = 1 - sample["y"] = SlidingWindowFeature(y, frames, labels=["speech"]) + sample["y"] = SlidingWindowFeature( + y, self.model.example_output.frames, labels=["speech"] + ) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} diff --git a/pyannote/audio/utils/multi_task.py b/pyannote/audio/utils/multi_task.py new file mode 100644 index 000000000..3886a0eeb --- /dev/null +++ b/pyannote/audio/utils/multi_task.py @@ -0,0 +1,59 @@ +# MIT License +# +# Copyright (c) 2023- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +from typing import Any, Callable, Tuple, Union + +from pyannote.audio.core.model import Specifications + + +def map_with_specifications( + specifications: Union[Specifications, Tuple[Specifications]], + func: Callable, + *iterables, +) -> Union[Any, Tuple[Any]]: + """Compute the function using arguments from each of the iterables + + Returns a tuple if provided `specifications` is a tuple, + otherwise returns the function return value. + + Parameters + ---------- + specifications : (tuple of) Specifications + Specifications or tuple of specifications + func : callable + Function called for each specification with + `func(*iterables[i], specifications=specifications[i])` + *iterables : + List of iterables with same length as `specifications`. + + Returns + ------- + output : (tuple of) `func` return value(s) + """ + + if isinstance(specifications, Specifications): + return func(*iterables, specifications=specifications) + + return tuple( + func(*i, specifications=s) for s, *i in zip(specifications, *iterables) + ) diff --git a/pyannote/audio/utils/powerset.py b/pyannote/audio/utils/powerset.py index 215cb7946..0f5cfb5bc 100644 --- a/pyannote/audio/utils/powerset.py +++ b/pyannote/audio/utils/powerset.py @@ -85,25 +85,29 @@ def build_cardinality(self) -> torch.Tensor: return cardinality def to_multilabel(self, powerset: torch.Tensor) -> torch.Tensor: - """Convert (hard) predictions from powerset to multi-label + """Convert predictions from (soft) powerset to (hard) multi-label Parameter --------- powerset : (batch_size, num_frames, num_powerset_classes) torch.Tensor - Hard predictions in "powerset" space. + Soft predictions in "powerset" space. Returns ------- multi_label : (batch_size, num_frames, num_classes) torch.Tensor Hard predictions in "multi-label" space. - - Note - ---- - This method will not complain if `powerset` is provided a soft predictions - (e.g. the output of a softmax-ed classifier). However, in that particular - case, the resulting soft multi-label output will not make much sense. """ - return torch.matmul(powerset, self.mapping) + + hard_powerset = torch.nn.functional.one_hot( + torch.argmax(powerset, dim=-1), + self.num_powerset_classes, + ).float() + + return torch.matmul(hard_powerset, self.mapping) + + def forward(self, powerset: torch.Tensor) -> torch.Tensor: + """Alias for `to_multilabel`""" + return self.to_multilabel(powerset) def to_powerset(self, multilabel: torch.Tensor) -> torch.Tensor: """Convert (hard) predictions from multi-label to powerset diff --git a/pyannote/audio/utils/preview.py b/pyannote/audio/utils/preview.py index 6094c71cd..fcdf4d124 100644 --- a/pyannote/audio/utils/preview.py +++ b/pyannote/audio/utils/preview.py @@ -256,7 +256,7 @@ def make_frame(T: float): return IPythonVideo(video_path, embed=True) -def preview_training_samples( +def BROKEN_preview_training_samples( model: Model, blank: float = 1.0, video_fps: int = 5, diff --git a/tests/inference_test.py b/tests/inference_test.py index 807f94cc1..bd5040394 100644 --- a/tests/inference_test.py +++ b/tests/inference_test.py @@ -1,13 +1,13 @@ import numpy as np import pytest import pytorch_lightning as pl +from pyannote.core import SlidingWindowFeature +from pyannote.database import FileFinder, get_protocol from pyannote.audio import Inference, Model from pyannote.audio.core.task import Resolution from pyannote.audio.models.segmentation.debug import SimpleSegmentationModel from pyannote.audio.tasks import VoiceActivityDetection -from pyannote.core import SlidingWindowFeature -from pyannote.database import FileFinder, get_protocol HF_SAMPLE_MODEL_ID = "pyannote/TestModelForContinuousIntegration" @@ -29,8 +29,8 @@ def trained(): ) vad = VoiceActivityDetection(protocol, duration=2.0, batch_size=16, num_workers=4) model = SimpleSegmentationModel(task=vad) - trainer = pl.Trainer(fast_dev_run=True) - trainer.fit(model, vad) + trainer = pl.Trainer(fast_dev_run=True, accelerator="cpu") + trainer.fit(model) return protocol, model @@ -91,7 +91,6 @@ def test_on_file_path(trained): def test_skip_aggregation(pretrained_model, dev_file): - inference = Inference(pretrained_model, skip_aggregation=True) scores = inference(dev_file) assert len(scores.data.shape) == 3 diff --git a/tests/test_train.py b/tests/test_train.py index 79e7f071a..7a7bfe338 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -20,125 +20,119 @@ def protocol(): def test_train_segmentation(protocol): segmentation = SpeakerDiarization(protocol) model = SimpleSegmentationModel(task=segmentation) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_train_voice_activity_detection(protocol): voice_activity_detection = VoiceActivityDetection(protocol) model = SimpleSegmentationModel(task=voice_activity_detection) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_train_overlapped_speech_detection(protocol): overlapped_speech_detection = OverlappedSpeechDetection(protocol) model = SimpleSegmentationModel(task=overlapped_speech_detection) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_finetune_with_task_that_does_not_need_setup_for_specs(protocol): voice_activity_detection = VoiceActivityDetection(protocol) model = SimpleSegmentationModel(task=voice_activity_detection) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) voice_activity_detection = VoiceActivityDetection(protocol) model.task = voice_activity_detection - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_finetune_with_task_that_needs_setup_for_specs(protocol): segmentation = SpeakerDiarization(protocol) model = SimpleSegmentationModel(task=segmentation) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) segmentation = SpeakerDiarization(protocol) model.task = segmentation - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_transfer_with_task_that_does_not_need_setup_for_specs(protocol): - segmentation = SpeakerDiarization(protocol) model = SimpleSegmentationModel(task=segmentation) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) voice_activity_detection = VoiceActivityDetection(protocol) model.task = voice_activity_detection - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_transfer_with_task_that_needs_setup_for_specs(protocol): - voice_activity_detection = VoiceActivityDetection(protocol) model = SimpleSegmentationModel(task=voice_activity_detection) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) segmentation = SpeakerDiarization(protocol) model.task = segmentation - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_finetune_freeze_with_task_that_needs_setup_for_specs(protocol): - segmentation = SpeakerDiarization(protocol) model = SimpleSegmentationModel(task=segmentation) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) segmentation = SpeakerDiarization(protocol) model.task = segmentation model.freeze_up_to("mfcc") - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_finetune_freeze_with_task_that_does_not_need_setup_for_specs(protocol): - vad = VoiceActivityDetection(protocol) model = SimpleSegmentationModel(task=vad) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) vad = VoiceActivityDetection(protocol) model.task = vad model.freeze_up_to("mfcc") - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_transfer_freeze_with_task_that_does_not_need_setup_for_specs(protocol): - segmentation = SpeakerDiarization(protocol) model = SimpleSegmentationModel(task=segmentation) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) voice_activity_detection = VoiceActivityDetection(protocol) model.task = voice_activity_detection model.freeze_up_to("mfcc") - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_transfer_freeze_with_task_that_needs_setup_for_specs(protocol): - voice_activity_detection = VoiceActivityDetection(protocol) model = SimpleSegmentationModel(task=voice_activity_detection) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) segmentation = SpeakerDiarization(protocol) model.task = segmentation model.freeze_up_to("mfcc") - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) diff --git a/tutorials/add_your_own_task.ipynb b/tutorials/add_your_own_task.ipynb index b2053f459..251846957 100644 --- a/tutorials/add_your_own_task.ipynb +++ b/tutorials/add_your_own_task.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -32,6 +33,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -48,6 +50,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -57,6 +60,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -82,6 +86,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -125,6 +130,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -176,54 +182,52 @@ " augmentation=augmentation,\n", " )\n", "\n", - " def setup(self, stage=None):\n", - "\n", - " if stage == \"fit\":\n", - "\n", - " # load metadata for training subset\n", - " self.train_metadata_ = list()\n", - " for training_file in self.protocol.train():\n", - " self.training_metadata_.append({\n", - " # path to audio file (str)\n", - " \"audio\": training_file[\"audio\"],\n", - " # duration of audio file (float)\n", - " \"duration\": training_file[\"duration\"],\n", - " # reference annotation (pyannote.core.Annotation)\n", - " \"annotation\": training_file[\"annotation\"],\n", - " })\n", - "\n", - " # gather the list of classes\n", - " classes = set()\n", - " for training_file in self.train_metadata_:\n", - " classes.update(training_file[\"reference\"].labels())\n", - " classes = sorted(classes)\n", - "\n", - " # specify the addressed problem\n", - " self.specifications = Specifications(\n", - " # it is a multi-label classification problem\n", - " problem=Problem.MULTI_LABEL_CLASSIFICATION,\n", - " # we expect the model to output one prediction \n", - " # for the whole chunk\n", - " resolution=Resolution.CHUNK,\n", - " # the model will ingest chunks with that duration (in seconds)\n", - " duration=self.duration,\n", - " # human-readable names of classes\n", - " classes=classes)\n", - "\n", - " # `has_validation` is True iff protocol defines a development set\n", - " if not self.has_validation:\n", - " return\n", - "\n", - " # load metadata for validation subset\n", - " self.validation_metadata_ = list()\n", - " for validation_file in self.protocol.development():\n", - " self.validation_metadata_.append({\n", - " \"audio\": validation_file[\"audio\"],\n", - " \"num_samples\": math.floor(validation_file[\"duration\"] / self.duration),\n", - " \"annotation\": validation_file[\"annotation\"],\n", - " })\n", - " \n", - " \n", + " def setup(self):\n", + "\n", + " # load metadata for training subset\n", + " self.train_metadata_ = list()\n", + " for training_file in self.protocol.train():\n", + " self.training_metadata_.append({\n", + " # path to audio file (str)\n", + " \"audio\": training_file[\"audio\"],\n", + " # duration of audio file (float)\n", + " \"duration\": training_file[\"duration\"],\n", + " # reference annotation (pyannote.core.Annotation)\n", + " \"annotation\": training_file[\"annotation\"],\n", + " })\n", + "\n", + " # gather the list of classes\n", + " classes = set()\n", + " for training_file in self.train_metadata_:\n", + " classes.update(training_file[\"reference\"].labels())\n", + " classes = sorted(classes)\n", + "\n", + " # specify the addressed problem\n", + " self.specifications = Specifications(\n", + " # it is a multi-label classification problem\n", + " problem=Problem.MULTI_LABEL_CLASSIFICATION,\n", + " # we expect the model to output one prediction \n", + " # for the whole chunk\n", + " resolution=Resolution.CHUNK,\n", + " # the model will ingest chunks with that duration (in seconds)\n", + " duration=self.duration,\n", + " # human-readable names of classes\n", + " classes=classes)\n", + "\n", + " # `has_validation` is True iff protocol defines a development set\n", + " if not self.has_validation:\n", + " return\n", + "\n", + " # load metadata for validation subset\n", + " self.validation_metadata_ = list()\n", + " for validation_file in self.protocol.development():\n", + " self.validation_metadata_.append({\n", + " \"audio\": validation_file[\"audio\"],\n", + " \"num_samples\": math.floor(validation_file[\"duration\"] / self.duration),\n", + " \"annotation\": validation_file[\"annotation\"],\n", + " })\n", + " \n", + " \n", "\n", " def train__iter__(self):\n", " # this method generates training samples, one at a time, \"ad infinitum\". each worker \n", diff --git a/tutorials/overlapped_speech_detection.ipynb b/tutorials/overlapped_speech_detection.ipynb index 78c6372cb..1ad5d4090 100644 --- a/tutorials/overlapped_speech_detection.ipynb +++ b/tutorials/overlapped_speech_detection.ipynb @@ -20,6 +20,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -39,6 +40,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -49,6 +51,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -84,6 +87,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -103,6 +107,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -110,6 +115,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -130,6 +136,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -147,6 +154,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -161,10 +169,11 @@ "source": [ "import pytorch_lightning as pl\n", "trainer = pl.Trainer(max_epochs=10)\n", - "trainer.fit(model, osd)" + "trainer.fit(model)" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -185,6 +194,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -212,6 +222,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -219,6 +230,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -242,6 +254,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -258,6 +271,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -265,6 +279,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -297,6 +312,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ From 4eb719046bc9c2aa21f2f935d7a52d7f9c229327 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Mon, 12 Jun 2023 13:10:50 +0200 Subject: [PATCH 03/83] fix(inference): fix multi-task inference --- pyannote/audio/core/inference.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pyannote/audio/core/inference.py b/pyannote/audio/core/inference.py index 703aa06cb..dcf21868d 100644 --- a/pyannote/audio/core/inference.py +++ b/pyannote/audio/core/inference.py @@ -262,19 +262,17 @@ def slide( step_size: int = round(self.step * sample_rate) _, num_samples = waveform.shape - frames = self.model.example_output.frames - def __frames( - frames, specifications: Optional[Specifications] = None + example_output, specifications: Optional[Specifications] = None ) -> SlidingWindow: if specifications.resolution == Resolution.CHUNK: return SlidingWindow(start=0.0, duration=self.duration, step=self.step) - return frames + return example_output.frames frames: Union[SlidingWindow, Tuple[SlidingWindow]] = map_with_specifications( self.model.specifications, __frames, - self.model.example_output.frames, + self.model.example_output, ) # prepare complete chunks From dcdfc15c5aa3692173f2ec29241f670127f6e8df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Thu, 15 Jun 2023 13:53:04 +0200 Subject: [PATCH 04/83] feat: update FAQtory default answer --- .faq/suggest.md | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/.faq/suggest.md b/.faq/suggest.md index 0a9233998..5fd8b252f 100644 --- a/.faq/suggest.md +++ b/.faq/suggest.md @@ -1,3 +1,5 @@ +Thank you for your issue. + {%- if questions -%} {% if questions|length == 1 %} We found the following entry in the [FAQ]({{ faq_url }}) which you may find helpful: @@ -9,12 +11,22 @@ We found the following entries in the [FAQ]({{ faq_url }}) which you may find he - [{{ question.title }}]({{ faq_url }}#{{ question.slug }}) {%- endfor %} -Feel free to close this issue if you found an answer in the FAQ. Otherwise, please give us a little time to review. - {%- else -%} -Thank you for your issue. Give us a little time to review it. - -PS. You might want to check the [FAQ]({{ faq_url }}) if you haven't done so already. +You might want to check the [FAQ]({{ faq_url }}) if you haven't done so already. {%- endif %} -This is an automated reply, generated by [FAQtory](https://github.com/willmcgugan/faqtory) +Feel free to close this issue if you found an answer in the FAQ. + +If your issue is a feature request, please read [this](https://xyproblem.info/) first and update your request accordingly, if needed. + +If your issue is a bug report, please provide a [minimum reproducible example](https://stackoverflow.com/help/minimal-reproducible-example) as a link to a self-contained [Google Colab](https://colab.research.google.com/) notebook containing everthing needed to reproduce the bug: + - installation + - data preparation + - model download + - etc. + +Providing an MRE will increase your chance of getting an answer from the community (either maintainers or other power users). + +[We](https://herve.niderb.fr) also offer paid scientific consulting services around speaker diarization (and speech processing in general). + +> This is an automated reply, generated by [FAQtory](https://github.com/willmcgugan/faqtory) From 87f49f9f60c6f46edbc464e515a1522039ca3e76 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Sat, 17 Jun 2023 15:33:49 +0200 Subject: [PATCH 05/83] add draft version of the joint diarization and embedding tasks --- pyannote/audio/tasks/__init__.py | 3 + .../speaker_diarization_and_embedding.py | 968 ++++++++++++++++++ 2 files changed, 971 insertions(+) create mode 100644 pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py diff --git a/pyannote/audio/tasks/__init__.py b/pyannote/audio/tasks/__init__.py index 6cbba258f..71936986d 100644 --- a/pyannote/audio/tasks/__init__.py +++ b/pyannote/audio/tasks/__init__.py @@ -28,6 +28,8 @@ ) from .embedding.arcface import SupervisedRepresentationLearningWithArcFace # isort:skip +from .joint_task.speaker_diarization_and_embedding import JointSpeakerDiarizationAndEmbedding + # Segmentation has been renamed to SpeakerDiarization but we keep Segmentation here for backward compatibility Segmentation = SpeakerDiarization @@ -41,4 +43,5 @@ "MultiLabelSegmentation", "SpeakerEmbedding", "Segmentation", + "JointSpeakerDiarizationAndEmbedding", ] diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py new file mode 100644 index 000000000..1bf5dae36 --- /dev/null +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -0,0 +1,968 @@ +# MIT License +# +# Copyright (c) 2020- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from collections import defaultdict +import itertools +import random +import numpy as np +import torch +from typing import Literal, Union, Sequence, Dict +import warnings + +from torch_audiomentations.core.transforms_interface import BaseWaveformTransform +from torchaudio.backend.common import AudioMetaData +from torchmetrics import Metric +from torchmetrics.classification import BinaryAUROC + +from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature +from pyannote.audio.core.task import Problem, Resolution, Specifications, Task +from pyannote.audio.utils.loss import binary_cross_entropy, mse_loss, nll_loss +from pyannote.audio.utils.permutation import permutate +from pyannote.audio.utils.powerset import Powerset +from pyannote.audio.utils.random import create_rng_for_worker +from pyannote.database.protocol import SegmentationProtocol, SpeakerDiarizationProtocol +from pyannote.database.protocol.protocol import Scope, Subset +from pyannote.audio.torchmetrics.classification import EqualErrorRate +from pyannote.audio.torchmetrics import ( + DiarizationErrorRate, + FalseAlarmRate, + MissedDetectionRate, + OptimalDiarizationErrorRate, + OptimalDiarizationErrorRateThreshold, + OptimalFalseAlarmRate, + OptimalMissedDetectionRate, + OptimalSpeakerConfusionRate, + SpeakerConfusionRate, +) + +Subtask = Literal["diarization", "embedding"] + +Subsets = list(Subset.__args__) +Scopes = list(Scope.__args__) +Subtasks = list(Subtask.__args__) + + +class JointSpeakerDiarizationAndEmbedding(Task): + """Joint speaker diarization and embedding task + + Usage + ----- + load a meta protocol containing both diarization (e.g. X.SpeakerDiarization.Pretraining) + and verification (e.g. VoxCeleb.SpeakerVerification.VoxCeleb) datasets + >>> from pyannote.database import registry + >>> protocol = registry.get_protocol(...) + + instantiate task + >>> task = JointSpeakerDiarizationAndEmbedding(protocol) + + instantiate multi-task model + >>> model = JointSpeakerDiarizationAndEmbeddingModel() + >>> model.task = task + + train as usual... + + """ + + def __init__( + self, + protocol, + duration: float = 5.0, + max_speaker_per_chunk: int = 3, + max_speakers_per_frame: int = 2, + batch_size: int = 32, + database_ratio : float = 0.5, + num_workers: int = None, + pin_memory: bool = False, + augmentation: BaseWaveformTransform = None + ) -> None: + super().__init__( + protocol, + duration=duration, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + augmentation=augmentation, + ) + + self.max_speaker_per_chunk = max_speaker_per_chunk + self.max_speakers_per_frame = max_speakers_per_frame + self.database_ratio = database_ratio + + + # keep track of the use of database available in the meta protocol + # * embedding databases are those with global speaker label scope + # * diarization databases are those with file or database speaker label scope + self.embedding_database_files = [] + self.diarization_database_files = [] + self._train = {} + + def get_file(self, file_id): + + file = dict() + + file["audio"] = str(self.audios[file_id], encoding="utf-8") + + _audio_info = self.audio_infos[file_id] + _encoding = self.audio_encodings[file_id] + + sample_rate = _audio_info["sample_rate"] + num_frames = _audio_info["num_frames"] + num_channels = _audio_info["num_channels"] + bits_per_sample = _audio_info["bits_per_sample"] + encoding = str(_encoding, encoding="utf-8") + file["torchaudio.info"] = AudioMetaData( + sample_rate=sample_rate, + num_frames=num_frames, + num_channels=num_channels, + bits_per_sample=bits_per_sample, + encoding=encoding, + ) + + return file + + def setup(self, stage="fit"): + """Setup method + + Parameters + ---------- + stage : {'fit', 'validate', 'test'}, optional + Setup stage. Defaults to 'fit'. + """ + + # duration of training chunks + # TODO: handle variable duration case + duration = getattr(self, "duration", 0.0) + + # list of possible values for each metadata key + metadata_unique_values = defaultdict(list) + + metadata_unique_values["subset"] = Subsets + metadata_unique_values["subtask"] = ["diarization", "embedding"] + + if isinstance(self.protocol, SpeakerDiarizationProtocol): + metadata_unique_values["scope"] = Scopes + + elif isinstance(self.protocol, SegmentationProtocol): + classes = getattr(self, "classes", list()) + + # make sure classes attribute exists (and set to None if it did not exist) + self.classes = getattr(self, "classes", None) + if self.classes is None: + classes = list() + # metadata_unique_values["classes"] = list(classes) + + audios = list() # list of path to audio files + audio_infos = list() + audio_encodings = list() + metadata = list() # list of metadata + + annotated_duration = list() # total duration of annotated regions (per file) + annotated_regions = list() # annotated regions + annotations = list() # actual annotations + annotated_classes = list() # list of annotated classes (per file) + unique_labels = list() + + if self.has_validation: + files_iter = itertools.chain( + self.protocol.train(), self.protocol.development() + ) + else: + files_iter = self.protocol.train() + + for file_id, file in enumerate(files_iter): + + # gather metadata and update metadata_unique_values so that each metadatum + # (e.g. source database or label) is represented by an integer. + metadatum = dict() + + # keep track of source database and subset (train, development, or test) + if file["database"] not in metadata_unique_values["database"]: + metadata_unique_values["database"].append(file["database"]) + metadatum["database"] = metadata_unique_values["database"].index( + file["database"] + ) + metadatum["subset"] = Subsets.index(file["subset"]) + + # keep track of speaker label scope (file, database, or global) for speaker diarization protocols + if isinstance(self.protocol, SpeakerDiarizationProtocol): + metadatum["scope"] = Scopes.index(file["scope"]) + # add the file to the embedding or diarization list according to the file database speaker + # labels scope + if file["scope"] == 'global': + self.embedding_database_files.append(file_id) + elif file["scope"] in ["database", "file"]: + self.diarization_database_files.append(file_id) + + # keep track of list of classes for regular segmentation protocols + # Different files may be annotated using a different set of classes + # (e.g. one database for speech/music/noise, and another one for male/female/child) + if isinstance(self.protocol, SegmentationProtocol): + + if "classes" in file: + local_classes = file["classes"] + else: + local_classes = file["annotation"].labels() + + # if task was not initialized with a fixed list of classes, + # we build it as the union of all classes found in files + if self.classes is None: + for klass in local_classes: + if klass not in classes: + classes.append(klass) + annotated_classes.append( + [classes.index(klass) for klass in local_classes] + ) + + # if task was initialized with a fixed list of classes, + # we make sure that all files use a subset of these classes + # if they don't, we issue a warning and ignore the extra classes + else: + extra_classes = set(local_classes) - set(self.classes) + if extra_classes: + warnings.warn( + f"Ignoring extra classes ({', '.join(extra_classes)}) found for file {file['uri']} ({file['database']}). " + ) + annotated_classes.append( + [ + self.classes.index(klass) + for klass in set(local_classes) & set(self.classes) + ] + ) + + remaining_metadata_keys = set(file) - set( + [ + "uri", + "database", + "subset", + "audio", + "torchaudio.info", + "scope", + "classes", + "annotation", + "annotated", + ] + ) + + # keep track of any other (integer or string) metadata provided by the protocol + # (e.g. a "domain" key for domain-adversarial training) + for key in remaining_metadata_keys: + + value = file[key] + + if isinstance(value, str): + if value not in metadata_unique_values[key]: + metadata_unique_values[key].append(value) + metadatum[key] = metadata_unique_values[key].index(value) + + elif isinstance(value, int): + metadatum[key] = value + + else: + warnings.warn( + f"Ignoring '{key}' metadata because of its type ({type(value)}). Only str and int are supported for now.", + category=UserWarning, + ) + + metadata.append(metadatum) + + database_unique_labels = list() + + # reset list of file-scoped labels + file_unique_labels = list() + + # path to audio file + audios.append(str(file["audio"])) + + # audio info + audio_info = file["torchaudio.info"] + audio_infos.append( + ( + audio_info.sample_rate, # sample rate + audio_info.num_frames, # number of frames + audio_info.num_channels, # number of channels + audio_info.bits_per_sample, # bits per sample + ) + ) + audio_encodings.append(audio_info.encoding) # encoding + + # annotated regions and duration + _annotated_duration = 0.0 + for segment in file["annotated"]: + + # skip annotated regions that are shorter than training chunk duration + if segment.duration < duration: + continue + + # append annotated region + annotated_region = ( + file_id, + segment.duration, + segment.start, + segment.end, + ) + annotated_regions.append(annotated_region) + + # increment annotated duration + _annotated_duration += segment.duration + + # append annotated duration + annotated_duration.append(_annotated_duration) + + # annotations + for segment, _, label in file["annotation"].itertracks(yield_label=True): + + # "scope" is provided by speaker diarization protocols to indicate + # whether speaker labels are local to the file ('file'), consistent across + # all files in a database ('database'), or globally consistent ('global') + + if "scope" in file: + + # 0 = 'file' + # 1 = 'database' + # 2 = 'global' + scope = Scopes.index(file["scope"]) + + # update list of file-scope labels + if label not in file_unique_labels: + file_unique_labels.append(label) + # and convert label to its (file-scope) index + file_label_idx = file_unique_labels.index(label) + + database_label_idx = global_label_idx = -1 + + if scope > 0: # 'database' or 'global' + + # update list of database-scope labels + if label not in database_unique_labels: + database_unique_labels.append(label) + + # and convert label to its (database-scope) index + database_label_idx = database_unique_labels.index(label) + + if scope > 1: # 'global' + + # update list of global-scope labels + if label not in unique_labels: + unique_labels.append(label) + # add class to the list of classes: + if label not in self._train: + self._train[label] = list() + self._train[label].append(file_id) + # and convert label to its (global-scope) index + global_label_idx = unique_labels.index(label) + + # basic segmentation protocols do not provide "scope" information + # as classes are global by definition + + else: + try: + file_label_idx = ( + database_label_idx + ) = global_label_idx = classes.index(label) + except ValueError: + # skip labels that are not in the list of classes + continue + + annotations.append( + ( + file_id, # index of file + segment.start, # start time + segment.end, # end time + file_label_idx, # file-scope label index + database_label_idx, # database-scope label index + global_label_idx, # global-scope index + ) + ) + + # since not all metadata keys are present in all files, fallback to -1 when a key is missing + metadata = [ + tuple(metadatum.get(key, -1) for key in metadata_unique_values) + for metadatum in metadata + ] + dtype = [(key, "i") for key in metadata_unique_values] + self.metadata = np.array(metadata, dtype=dtype) + + # NOTE: read with str(self.audios[file_id], encoding='utf-8') + self.audios = np.array(audios, dtype=np.string_) + + # turn list of files metadata into a single numpy array + # TODO: improve using https://github.com/pytorch/pytorch/issues/13246#issuecomment-617140519 + + dtype = [ + ("sample_rate", "i"), + ("num_frames", "i"), + ("num_channels", "i"), + ("bits_per_sample", "i"), + ] + self.audio_infos = np.array(audio_infos, dtype=dtype) + self.audio_encodings = np.array(audio_encodings, dtype=np.string_) + + self.annotated_duration = np.array(annotated_duration) + + # turn list of annotated regions into a single numpy array + dtype = [("file_id", "i"), ("duration", "f"), ("start", "f"), ("end", "f")] + self.annotated_regions = np.array(annotated_regions, dtype=dtype) + + # convert annotated_classes (which is a list of list of classes, one list of classes per file) + # into a single (num_files x num_classes) numpy array: + # * True indicates that this particular class was annotated for this particular file (though it may not be active in this file) + # * False indicates that this particular class was not even annotated (i.e. its absence does not imply that it is not active in this file) + if isinstance(self.protocol, SegmentationProtocol) and self.classes is None: + self.classes = classes + self.annotated_classes = np.zeros( + (len(annotated_classes), len(self.classes)), dtype=np.bool_ + ) + for file_id, classes in enumerate(annotated_classes): + self.annotated_classes[file_id, classes] = True + + # turn list of annotations into a single numpy array + dtype = [ + ("file_id", "i"), + ("start", "f"), + ("end", "f"), + ("file_label_idx", "i"), + ("database_label_idx", "i"), + ("global_label_idx", "i"), + ] + self.annotations = np.array(annotations, dtype=dtype) + + self.metadata_unique_values = metadata_unique_values + + if not self.has_validation: + return + + validation_chunks = list() + + # obtain indexes of files in the validation subset + validation_file_ids = np.where( + self.metadata["subset"] == Subsets.index("development") + )[0] + + # iterate over files in the validation subset + for file_id in validation_file_ids: + + # get annotated regions in file + annotated_regions = self.annotated_regions[ + self.annotated_regions["file_id"] == file_id + ] + + # iterate over annotated regions + for annotated_region in annotated_regions: + + # number of chunks in annotated region + num_chunks = round(annotated_region["duration"] // duration) + + # iterate over chunks + for c in range(num_chunks): + start_time = annotated_region["start"] + c * duration + validation_chunks.append((file_id, start_time, duration)) + + dtype = [("file_id", "i"), ("start", "f"), ("duration", "f")] + self.validation_chunks = np.array(validation_chunks, dtype=dtype) + + speaker_diarization = Specifications( + duration=self.duration, + resolution=Resolution.FRAME, + problem=Problem.MONO_LABEL_CLASSIFICATION, + permutation_invariant=True, + classes=[f"speaker{i+1}" for i in range(self.max_speaker_per_chunk)], + powerset_max_classes=self.max_speakers_per_frame, + ) + speaker_embedding = Specifications( + duration=self.duration, + resolution=Resolution.CHUNK, + problem=Problem.REPRESENTATION, + classes=sorted(self._train), + ) + + self.specifications = (speaker_diarization, speaker_embedding) + + def prepare_chunk(self, file_id: int, start_time: float, duration: float): + """Prepare chunk + + Parameters + ---------- + file_id : int + File index + start_time : float + Chunk start time + duration : float + Chunk duration. + + Returns + ------- + sample : dict + Dictionary containing the chunk data with the following keys: + - `X`: waveform + - `y`: target as a SlidingWindowFeature instance where y.labels is + in meta.scope space. + - `meta`: + - `scope`: target scope (0: file, 1: database, 2: global) + - `database`: database index + - `file`: file index + """ + + file = self.get_file(file_id) + + # get label scope + label_scope = Scopes[self.metadata[file_id]["scope"]] + label_scope_key = f"{label_scope}_label_idx" + + # + chunk = Segment(start_time, start_time + duration) + + sample = dict() + sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) + + # use model introspection to predict how many frames it will output + # TODO: this should be cached + num_samples = sample["X"].shape[1] + num_frames, _ = self.model.introspection(num_samples) + resolution = duration / num_frames + frames = SlidingWindow(start=0.0, duration=resolution, step=resolution) + + # gather all annotations of current file + annotations = self.annotations[self.annotations["file_id"] == file_id] + + # gather all annotations with non-empty intersection with current chunk + chunk_annotations = annotations[ + (annotations["start"] < chunk.end) & (annotations["end"] > chunk.start) + ] + + # discretize chunk annotations at model output resolution + start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start + start_idx = np.floor(start / resolution).astype(int) + end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start + end_idx = np.ceil(end / resolution).astype(int) + + # get list and number of labels for current scope + labels = list(np.unique(chunk_annotations[label_scope_key])) + num_labels = len(labels) + + if num_labels > self.max_speakers_per_chunk: + pass + + # initial frame-level targets + y = np.zeros((num_frames, num_labels), dtype=np.uint8) + + # map labels to indices + mapping = {label: idx for idx, label in enumerate(labels)} + + for start, end, label in zip( + start_idx, end_idx, chunk_annotations[label_scope_key] + ): + mapped_label = mapping[label] + y[start:end, mapped_label] = 1 + + sample["y"] = SlidingWindowFeature(y, frames, labels=labels) + + metadata = self.metadata[file_id] + sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} + sample["meta"]["file"] = file_id + + return sample + + def train__iter__helper(self, rng : random.Random, **filters): + """Iterate over training samples with optional domain filtering + + Parameters + ---------- + rng : random.Random + Random number generator + filters : dict, optional + When provided (as {key : value} dict), filter training files so that + only file such as file [key] == value are used for generating chunks + + Yields + ------ + chunk : dict + Training chunks + """ + + # indices of trainijng files that matches domain filters + + # select file dataset (embedding or diarization) according to the probability + # to ratio between these two kind of dataset: + + training = self.metadata["subset"] == Subsets.index("train") + for key, value in filters.items(): + training &= self.metadata[key] == value + file_ids = np.where(training)[0] + annotated_duration = self.annotated_duration[file_ids] + prob_annotated_duration = annotated_duration / np.sum(annotated_duration) + + duration = self.duration + + embedding_classes = list(self.specifications[Subtasks.index("embedding")]) + embedding_class_idx = 0 + + file_ids_diarization = file_ids[np.in1d(file_ids, self.diarization_database_files)] + + while True: + print("here") + # select one file at random (wiht probability proportional to its annotated duration) + # according to the ratio bewteen embedding and diarization dataset + if np.random.uniform() < self.database_ratio: + subtask = Subtasks.index("diarization") + file_id = np.random.choice(file_ids_diarization, p=prob_annotated_duration) + else: + subtask = Subtasks.index("embedding") + # shuffle embedding classes list and go through this shuffled list + # to make sure to see all the speakers during training + if embedding_class_idx == len(embedding_classes): + rng.shuffle(embedding_classes) + embedding_class_idx = 0 + # get files id for current class and sample one of these files + class_files_ids = self._train[embedding_classes[embedding_class_idx]] + embedding_class_idx += 1 + file_id = np.random.choice(class_files_ids) + + + # find indices of annotated regions in this file + annotated_region_indices = np.where( + self.annotated_regions["file_id"] == file_id + )[0] + + # turn annotated regions duration into a probability distribution + prob_annotaded_regions_duration = self.annotated_regions["duration"][ + annotated_region_indices + ] / np.sum(self.annotated_regions["duration"][annotated_region_indices]) + + # seletect one annotated region at random (with probability proportional to its duration) + annotated_region_index = np.random.choice(annotated_region_indices, p=prob_annotaded_regions_duration + ) + + # select one chunk at random in this annotated region + _, _, start, end = self.annotated_regions[annotated_region_index] + start_time = rng.uniform(start, end - duration) + sample = self.prepare_chunk(file_id, start_time, duration) + sample["task"] = subtask + yield sample + + def train__iter__(self): + """Iterate over trainig samples + + Yields + ------ + dict: + x: (time, channel) + Audio chunks. + task: "diarization" or "embedding" + y: target speaker label for speaker embedding task, + (frame, ) frame-level targets for speaker diarization task. + Note that frame < time. + `frame is infered automagically from the exemple model output` + """ + + # create worker-specific random number generator + rng = create_rng_for_worker(self.model.current_epoch) + + balance = getattr(self, "balance", None) + if balance is None: + chunks = self.train__iter__helper(rng) + else: + # create + subchunks = dict() + for product in itertools.product([self.metadata_unique_values[key] for key in balance]): + filters = {key : value for key, value in zip(balance, product)} + subchunks[product] = self.train__iter__helper(rng, **filters) + + while True: + # select one subchunck generator at random (with uniform probability) + # so thath it is balanced on average + if balance is not None: + chunks = subchunks[rng.choice(subchunks)] + + # generate random chunk + print(chunks) + yield next(chunks) + + def segmentation_loss( + self, + permutated_prediction: torch.Tensor, + target: torch.Tensor, + weight: torch.Tensor = None, + ) -> torch.Tensor: + """Permutation-invariant segmentation loss + + Parameters + ---------- + permutated_prediction : (batch_size, num_frames, num_classes) torch.Tensor + Permutated speaker activity predictions. + target : (batch_size, num_frames, num_speakers) torch.Tensor + Speaker activity. + weight : (batch_size, num_frames, 1) torch.Tensor, optional + Frames weight. + + Returns + ------- + seg_loss : torch.Tensor + Permutation-invariant segmentation loss + """ + + if self.specifications[Subtasks.index("diarization")].powerset: + + # `clamp_min` is needed to set non-speech weight to 1. + class_weight = ( + torch.clamp_min(self.model.powerset.cardinality, 1.0) + if self.weigh_by_cardinality + else None + ) + seg_loss = nll_loss( + permutated_prediction, + torch.argmax(target, dim=-1), + class_weight=class_weight, + weight=weight, + ) + else: + seg_loss = binary_cross_entropy( + permutated_prediction, target.float(), weight=weight + ) + + return seg_loss + + def voice_activity_detection_loss( + self, + permutated_prediction: torch.Tensor, + target: torch.Tensor, + weight: torch.Tensor = None, + ) -> torch.Tensor: + """Voice activity detection loss + + Parameters + ---------- + permutated_prediction : (batch_size, num_frames, num_classes) torch.Tensor + Speaker activity predictions. + target : (batch_size, num_frames, num_speakers) torch.Tensor + Speaker activity. + weight : (batch_size, num_frames, 1) torch.Tensor, optional + Frames weight. + + Returns + ------- + vad_loss : torch.Tensor + Voice activity detection loss. + """ + + vad_prediction, _ = torch.max(permutated_prediction, dim=2, keepdim=True) + # (batch_size, num_frames, 1) + + vad_target, _ = torch.max(target.float(), dim=2, keepdim=False) + # (batch_size, num_frames) + + if self.vad_loss == "bce": + loss = binary_cross_entropy(vad_prediction, vad_target, weight=weight) + + elif self.vad_loss == "mse": + loss = mse_loss(vad_prediction, vad_target, weight=weight) + + return loss + + def setup_loss_func(self): + diarization_spec = self.specifications[Subtasks.index("diarization")] + if diarization_spec.powerset: + self.model.powerset = Powerset( + len(diarization_spec.classes), + diarization_spec.powerset_max_classes, + ) + + def compute_diarization_loss(self, batch : torch.Tensor): + """""" + X, y = batch["X"], batch["y"] + # drop samples that contain too many speakers + num_speakers: torch.Tensor = torch.sum(torch.any(target, dim=1), dim=1) + keep : torch.Tensor = num_speakers <= self.max_speaker_per_chunk + target = target[keep] + waveform = waveform[keep] + + # log effective batch size + self.model.log( + f"{self.logging_prefix}BatchSize", + keep.sum(), + prog_bar=False, + logger=True, + on_step=False, + on_epoch=True, + reduce_fx="mean", + ) + # corner case + if not keep.any(): + return None + + # forward pass + prediction = self.model(waveform) + batch_size, num_frames, _ = prediction.shape + # (batch_size, num_frames, num_classes) + # frames weight + weight_key = getattr(self, "weight", None) + weight = batch.get( + weight_key, + torch.ones(batch_size, num_frames, 1, device=self.model.device), + ) + # (batch_size, num_frames, 1) + + # warm-up + warm_up_left = round(self.warm_up[0] / self.duration * num_frames) + weight[:, :warm_up_left] = 0.0 + warm_up_right = round(self.warm_up[1] / self.duration * num_frames) + weight[:, num_frames - warm_up_right :] = 0.0 + + if self.specifications[Subtasks.index("diarization")].powerset: + + powerset = torch.nn.functional.one_hot( + torch.argmax(prediction, dim=-1), + self.model.powerset.num_powerset_classes, + ).float() + multilabel = self.model.powerset.to_multilabel(powerset) + permutated_target, _ = permutate(multilabel, target) + permutated_target_powerset = self.model.powerset.to_powerset( + permutated_target.float() + ) + seg_loss = self.segmentation_loss( + prediction, permutated_target_powerset, weight=weight + ) + + else: + permutated_prediction, _ = permutate(target, prediction) + seg_loss = self.segmentation_loss( + permutated_prediction, target, weight=weight + ) + + self.model.log( + f"{self.logging_prefix}TrainSegLoss", + seg_loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + if self.vad_loss is None: + vad_loss = 0.0 + + else: + + # TODO: vad_loss probably does not make sense in powerset mode + # because first class (empty set of labels) does exactly this... + if self.specifications[Subtasks.index("diarization")].powerset: + vad_loss = self.voice_activity_detection_loss( + prediction, permutated_target_powerset, weight=weight + ) + + else: + vad_loss = self.voice_activity_detection_loss( + permutated_prediction, target, weight=weight + ) + + self.model.log( + f"{self.logging_prefix}TrainVADLoss", + vad_loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + loss = seg_loss + vad_loss + # skip batch if something went wrong for some reason + if torch.isnan(loss): + return None + + self.model.log( + f"{self.logging_prefix}TrainLoss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return {"loss": loss} + + def compute_embedding_loss(self, batch : torch.Tensor): + X, y = batch["X", batch["y"]] + loss = self.model.loss_func(self.model(X), y) + + # skip batch if something went wrong for some reason + if torch.isnan(loss): + return None + + self.model.log( + f"{self.logging_prefix}TrainLoss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return {"loss": loss} + + def training_step(self, batch, batch_idx: int): + """Compute loss for the joint task + + Parameters + ---------- + batch : (usually) dict of torch.Tensor + current batch. + batch_idx: int + Batch index. + + Returns + ------- + loss : {str: torch.tensor} + {"loss": loss} + """ + + alpha = 0.5 + if batch["task"] == "diarization": + # compute diarization loss + diarization_loss = self.compute_diarization_loss(batch=batch) + if batch["task"] == "embedding": + # compute embedding loss + embedding_loss = self.compute_embedding_loss(batch=batch) + loss = alpha * diarization_loss + (1 - alpha) * embedding_loss + + def default_metric( + self, + ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: + """Returns diarization error rate and its components for diarization subtask, + and equal error rate for the embedding part + """ + + if self.specifications[Subtasks.index("diarization")].powerset: + return { + "DiarizationErrorRate": DiarizationErrorRate(0.5), + "DiarizationErrorRate/Confusion": SpeakerConfusionRate(0.5), + "DiarizationErrorRate/Miss": MissedDetectionRate(0.5), + "DiarizationErrorRate/FalseAlarm": FalseAlarmRate(0.5), + "EqualErrorRate": EqualErrorRate(compute_on_cpu=True, distances=False), + "BinaryAUROC": BinaryAUROC(compute_on_cpu=True), + } + + return { + "DiarizationErrorRate": OptimalDiarizationErrorRate(), + "DiarizationErrorRate/Threshold": OptimalDiarizationErrorRateThreshold(), + "DiarizationErrorRate/Confusion": OptimalSpeakerConfusionRate(), + "DiarizationErrorRate/Miss": OptimalMissedDetectionRate(), + "DiarizationErrorRate/FalseAlarm": OptimalFalseAlarmRate(), + "EqualErrorRate": EqualErrorRate(compute_on_cpu=True, distances=False), + "BinaryAUROC": BinaryAUROC(compute_on_cpu=True), + } From 58599c9cb4cd54f52d1173a59e67178f81779ec7 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Mon, 19 Jun 2023 10:50:05 +0200 Subject: [PATCH 06/83] update `train__iter__helper` method of the joint task - fixes the dimension error between files id and probabilties arrays - changes the way of how chunks for the embedding task are sampled - creates two functions to draw chunks, one for each subtask Tests are required to ensure that there are no bugs --- .../speaker_diarization_and_embedding.py | 158 ++++++++++++------ 1 file changed, 111 insertions(+), 47 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 1bf5dae36..cec2e7343 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -25,7 +25,7 @@ import random import numpy as np import torch -from typing import Literal, Union, Sequence, Dict +from typing import Literal, List, Text, Union, Sequence, Dict import warnings from torch_audiomentations.core.transforms_interface import BaseWaveformTransform @@ -113,8 +113,7 @@ def __init__( # * diarization databases are those with file or database speaker label scope self.embedding_database_files = [] self.diarization_database_files = [] - self._train = {} - + def get_file(self, file_id): file = dict() @@ -205,8 +204,8 @@ def setup(self, stage="fit"): # keep track of speaker label scope (file, database, or global) for speaker diarization protocols if isinstance(self.protocol, SpeakerDiarizationProtocol): metadatum["scope"] = Scopes.index(file["scope"]) - # add the file to the embedding or diarization list according to the file database speaker - # labels scope + # add the file to the embedding or diarization list according to the file database speaker + # labels scope if file["scope"] == 'global': self.embedding_database_files.append(file_id) elif file["scope"] in ["database", "file"]: @@ -363,10 +362,6 @@ def setup(self, stage="fit"): # update list of global-scope labels if label not in unique_labels: unique_labels.append(label) - # add class to the list of classes: - if label not in self._train: - self._train[label] = list() - self._train[label].append(file_id) # and convert label to its (global-scope) index global_label_idx = unique_labels.index(label) @@ -491,7 +486,7 @@ def setup(self, stage="fit"): duration=self.duration, resolution=Resolution.CHUNK, problem=Problem.REPRESENTATION, - classes=sorted(self._train), + classes=unique_labels, ) self.specifications = (speaker_diarization, speaker_embedding) @@ -581,6 +576,88 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): return sample + def draw_diarization_chunk(self, file_ids : np.ndarray, + prob_annotated_duration : np.ndarray, + rng : random.Random, + duration : float, + ) -> tuple: + """Sample one chunk for the diarization task + + Parameters + ---------- + file_ids: np.ndarray + array containing files id + prob_annotated_duration: np.ndarray + array of the same size than file_ids array, containing probability + to corresponding file to be drawn + rng : random.Random + Random number generator + duration: float + duration of the chunk to draw + """ + # select one file at random (wiht probability proportional to its annotated duration) + file_id = np.random.choice(file_ids, p=prob_annotated_duration) + # find indices of annotated regions in this file + annotated_region_indices = np.where( + self.annotated_regions["file_id"] == file_id + )[0] + + # turn annotated regions duration into a probability distribution + prob_annotaded_regions_duration = self.annotated_regions["duration"][ + annotated_region_indices + ] / np.sum(self.annotated_regions["duration"][annotated_region_indices]) + + # seletect one annotated region at random (with probability proportional to its duration) + annotated_region_index = np.random.choice(annotated_region_indices, + p=prob_annotaded_regions_duration + ) + + # select one chunk at random in this annotated region + _, _, start, end = self.annotated_regions[annotated_region_index] + start_time = rng.uniform(start, end - duration) + + return (file_id, start_time) + + def draw_embedding_chunk(self, klass : Text, + classes : List[Text], + duration : float) -> tuple: + """Sample one chunk for the embedding task + + Parameters + ---------- + klass: Text + current class of speakers from which to draw a sample + classes: List[Text] + list of all the global speaker labels, in the same order than the list + defined in the task specification + duration: float + duration of the chunk to draw + + Return + ------ + tuple: + file_id: + the file id to which the sampled chunk belongs + start_time: + start time of the sampled chunk + """ + # get index of the current class in the order of original class list + class_id = classes.index(klass) + # get segments for current class + class_segments_idx = self.annotations["global_label_idx"] == class_id + class_segments = self.annotations[class_segments_idx] + + # sample one segment from all the class segments: + segments_duration = class_segments["end"] - class_segments["start"] + segments_total_duration = np.sum(segments_duration) + prob_segments = segments_duration / segments_total_duration + segment = np.random.choice(class_segments, p=prob_segments) + + # sample chunk start time in order to intersect it with the sampled segment + start_time = np.random.uniform(segment["start"] - duration / 2, segment["start"]) + + return (segment["file_id"], start_time) + def train__iter__helper(self, rng : random.Random, **filters): """Iterate over training samples with optional domain filtering @@ -591,7 +668,7 @@ def train__iter__helper(self, rng : random.Random, **filters): filters : dict, optional When provided (as {key : value} dict), filter training files so that only file such as file [key] == value are used for generating chunks - + Yields ------ chunk : dict @@ -607,53 +684,40 @@ def train__iter__helper(self, rng : random.Random, **filters): for key, value in filters.items(): training &= self.metadata[key] == value file_ids = np.where(training)[0] + # get the subset of embedding database files from training files + embedding_files_ids = file_ids[np.in1d(file_ids, self.embedding_database_files)] + annotated_duration = self.annotated_duration[file_ids] prob_annotated_duration = annotated_duration / np.sum(annotated_duration) - + # set probability to sample a file from embedding database to 0 + prob_annotated_duration[embedding_files_ids] = 0 + duration = self.duration - embedding_classes = list(self.specifications[Subtasks.index("embedding")]) + # make a copy of the original classes list, in order to not modify it when shuffling* + embedding_classes = self.specifications[Subtasks.index("embedding")].classes + shuffled_embedding_classes = list(embedding_classes) embedding_class_idx = 0 - file_ids_diarization = file_ids[np.in1d(file_ids, self.diarization_database_files)] - while True: - print("here") - # select one file at random (wiht probability proportional to its annotated duration) - # according to the ratio bewteen embedding and diarization dataset + # choose between diarization or embedding subtask according to a ratio + # between these two tasks if np.random.uniform() < self.database_ratio: subtask = Subtasks.index("diarization") - file_id = np.random.choice(file_ids_diarization, p=prob_annotated_duration) + file_id, start_time = self.draw_diarization_chunk(file_ids, prob_annotated_duration, rng, duration) else: subtask = Subtasks.index("embedding") # shuffle embedding classes list and go through this shuffled list # to make sure to see all the speakers during training - if embedding_class_idx == len(embedding_classes): - rng.shuffle(embedding_classes) + if embedding_class_idx == len(shuffled_embedding_classes): + rng.shuffle(shuffled_embedding_classes) embedding_class_idx = 0 - # get files id for current class and sample one of these files - class_files_ids = self._train[embedding_classes[embedding_class_idx]] + klass = shuffled_embedding_classes[embedding_class_idx] embedding_class_idx += 1 - file_id = np.random.choice(class_files_ids) - - - # find indices of annotated regions in this file - annotated_region_indices = np.where( - self.annotated_regions["file_id"] == file_id - )[0] - - # turn annotated regions duration into a probability distribution - prob_annotaded_regions_duration = self.annotated_regions["duration"][ - annotated_region_indices - ] / np.sum(self.annotated_regions["duration"][annotated_region_indices]) - - # seletect one annotated region at random (with probability proportional to its duration) - annotated_region_index = np.random.choice(annotated_region_indices, p=prob_annotaded_regions_duration - ) - - # select one chunk at random in this annotated region - _, _, start, end = self.annotated_regions[annotated_region_index] - start_time = rng.uniform(start, end - duration) + file_id, start_time = self.draw_embedding_chunk(klass, + classes=embedding_classes, + duration=duration) + sample = self.prepare_chunk(file_id, start_time, duration) sample["task"] = subtask yield sample @@ -807,7 +871,7 @@ def compute_diarization_loss(self, batch : torch.Tensor): # corner case if not keep.any(): return None - + # forward pass prediction = self.model(waveform) batch_size, num_frames, _ = prediction.shape @@ -917,14 +981,14 @@ def compute_embedding_loss(self, batch : torch.Tensor): def training_step(self, batch, batch_idx: int): """Compute loss for the joint task - + Parameters ---------- batch : (usually) dict of torch.Tensor current batch. batch_idx: int Batch index. - + Returns ------- loss : {str: torch.tensor} @@ -943,7 +1007,7 @@ def training_step(self, batch, batch_idx: int): def default_metric( self, ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: - """Returns diarization error rate and its components for diarization subtask, + """Returns diarization error rate and its components for diarization subtask, and equal error rate for the embedding part """ From 04de82fe237f778db68be2c2ae0e6b3baded3a2d Mon Sep 17 00:00:00 2001 From: clement-pages Date: Mon, 19 Jun 2023 14:49:14 +0200 Subject: [PATCH 07/83] fix `StopIteration` error --- .../speaker_diarization_and_embedding.py | 65 +++++++++---------- 1 file changed, 29 insertions(+), 36 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index cec2e7343..0ddb19c8f 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -86,7 +86,7 @@ def __init__( self, protocol, duration: float = 5.0, - max_speaker_per_chunk: int = 3, + max_speakers_per_chunk: int = 3, max_speakers_per_frame: int = 2, batch_size: int = 32, database_ratio : float = 0.5, @@ -103,7 +103,7 @@ def __init__( augmentation=augmentation, ) - self.max_speaker_per_chunk = max_speaker_per_chunk + self.max_speakers_per_chunk = max_speakers_per_chunk self.max_speakers_per_frame = max_speakers_per_frame self.database_ratio = database_ratio @@ -155,7 +155,6 @@ def setup(self, stage="fit"): metadata_unique_values = defaultdict(list) metadata_unique_values["subset"] = Subsets - metadata_unique_values["subtask"] = ["diarization", "embedding"] if isinstance(self.protocol, SpeakerDiarizationProtocol): metadata_unique_values["scope"] = Scopes @@ -479,7 +478,7 @@ def setup(self, stage="fit"): resolution=Resolution.FRAME, problem=Problem.MONO_LABEL_CLASSIFICATION, permutation_invariant=True, - classes=[f"speaker{i+1}" for i in range(self.max_speaker_per_chunk)], + classes=[f"speaker{i+1}" for i in range(self.max_speakers_per_chunk)], powerset_max_classes=self.max_speakers_per_frame, ) speaker_embedding = Specifications( @@ -491,7 +490,7 @@ def setup(self, stage="fit"): self.specifications = (speaker_diarization, speaker_embedding) - def prepare_chunk(self, file_id: int, start_time: float, duration: float): + def prepare_chunk(self, file_id: int, start_time: float, duration: float, subtask: int): """Prepare chunk Parameters @@ -502,6 +501,9 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): Chunk start time duration : float Chunk duration. + subtask: int + - 0 : diarization task + - 1 : embedding task Returns ------- @@ -528,13 +530,6 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample = dict() sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) - # use model introspection to predict how many frames it will output - # TODO: this should be cached - num_samples = sample["X"].shape[1] - num_frames, _ = self.model.introspection(num_samples) - resolution = duration / num_frames - frames = SlidingWindow(start=0.0, duration=resolution, step=resolution) - # gather all annotations of current file annotations = self.annotations[self.annotations["file_id"] == file_id] @@ -545,9 +540,10 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # discretize chunk annotations at model output resolution start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - start_idx = np.floor(start / resolution).astype(int) + # TODO handle tuple outputs from the model + start_idx = np.floor(start / self.model.example_output[0].frames.step).astype(int) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - end_idx = np.ceil(end / resolution).astype(int) + end_idx = np.ceil(end / self.model.example_output[0].frames.step).astype(int) # get list and number of labels for current scope labels = list(np.unique(chunk_annotations[label_scope_key])) @@ -557,7 +553,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): pass # initial frame-level targets - y = np.zeros((num_frames, num_labels), dtype=np.uint8) + y = np.zeros((self.model.example_output[0].num_frames, num_labels), dtype=np.uint8) # map labels to indices mapping = {label: idx for idx, label in enumerate(labels)} @@ -568,11 +564,14 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): mapped_label = mapping[label] y[start:end, mapped_label] = 1 - sample["y"] = SlidingWindowFeature(y, frames, labels=labels) + sample["y"] = SlidingWindowFeature( + y, self.model.example_output[subtask].frames, labels=labels + ) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} sample["meta"]["file"] = file_id + sample["task"] = subtask return sample @@ -582,7 +581,7 @@ def draw_diarization_chunk(self, file_ids : np.ndarray, duration : float, ) -> tuple: """Sample one chunk for the diarization task - + Parameters ---------- file_ids: np.ndarray @@ -622,7 +621,7 @@ def draw_embedding_chunk(self, klass : Text, classes : List[Text], duration : float) -> tuple: """Sample one chunk for the embedding task - + Parameters ---------- klass: Text @@ -632,7 +631,7 @@ def draw_embedding_chunk(self, klass : Text, defined in the task specification duration: float duration of the chunk to draw - + Return ------ tuple: @@ -676,10 +675,6 @@ def train__iter__helper(self, rng : random.Random, **filters): """ # indices of trainijng files that matches domain filters - - # select file dataset (embedding or diarization) according to the probability - # to ratio between these two kind of dataset: - training = self.metadata["subset"] == Subsets.index("train") for key, value in filters.items(): training &= self.metadata[key] == value @@ -694,7 +689,7 @@ def train__iter__helper(self, rng : random.Random, **filters): duration = self.duration - # make a copy of the original classes list, in order to not modify it when shuffling* + # make a copy of the original classes list, in order to not modify it when shuffling embedding_classes = self.specifications[Subtasks.index("embedding")].classes shuffled_embedding_classes = list(embedding_classes) embedding_class_idx = 0 @@ -717,9 +712,8 @@ def train__iter__helper(self, rng : random.Random, **filters): file_id, start_time = self.draw_embedding_chunk(klass, classes=embedding_classes, duration=duration) - - sample = self.prepare_chunk(file_id, start_time, duration) - sample["task"] = subtask + + sample = self.prepare_chunk(file_id, start_time, duration, subtask) yield sample def train__iter__(self): @@ -750,15 +744,14 @@ def train__iter__(self): filters = {key : value for key, value in zip(balance, product)} subchunks[product] = self.train__iter__helper(rng, **filters) - while True: - # select one subchunck generator at random (with uniform probability) - # so thath it is balanced on average - if balance is not None: - chunks = subchunks[rng.choice(subchunks)] + while True: + # select one subchunck generator at random (with uniform probability) + # so thath it is balanced on average + if balance is not None: + chunks = subchunks[rng.choice(subchunks)] - # generate random chunk - print(chunks) - yield next(chunks) + # generate random chunk + yield next(chunks) def segmentation_loss( self, @@ -854,7 +847,7 @@ def compute_diarization_loss(self, batch : torch.Tensor): X, y = batch["X"], batch["y"] # drop samples that contain too many speakers num_speakers: torch.Tensor = torch.sum(torch.any(target, dim=1), dim=1) - keep : torch.Tensor = num_speakers <= self.max_speaker_per_chunk + keep : torch.Tensor = num_speakers <= self.max_speakers_per_chunk target = target[keep] waveform = waveform[keep] From d8cb598491d7e951f8c90a691f70aeb2f99f81a5 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Mon, 19 Jun 2023 15:40:43 +0200 Subject: [PATCH 08/83] add missing collate methods For now this is a copy past from methods in segmentation task. --- .../speaker_diarization_and_embedding.py | 104 +++++++++++++++++- 1 file changed, 103 insertions(+), 1 deletion(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 0ddb19c8f..fc08a4d74 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -32,6 +32,7 @@ from torchaudio.backend.common import AudioMetaData from torchmetrics import Metric from torchmetrics.classification import BinaryAUROC +from torch.utils.data._utils.collate import default_collate from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature from pyannote.audio.core.task import Problem, Resolution, Specifications, Task @@ -571,7 +572,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float, subtas metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} sample["meta"]["file"] = file_id - sample["task"] = subtask + sample["meta"]["subtask"] = subtask return sample @@ -753,6 +754,107 @@ def train__iter__(self): # generate random chunk yield next(chunks) + def collate_X(self, batch) -> torch.Tensor: + """Collate for data""" + return default_collate([b["X"] for b in batch]) + + def collate_y(self, batch) -> torch.Tensor: + """ + Parameters + ---------- + batch : list + List of samples to collate. + "y" field is expected to be a SlidingWindowFeature. + + Returns + ------- + y : torch.Tensor + Collated target tensor of shape (num_frames, self.max_speakers_per_chunk) + If one chunk has more than `self.max_speakers_per_chunk` speakers, we keep + the max_speakers_per_chunk most talkative ones. If it has less, we pad with + zeros (artificial inactive speakers). + """ + + collated_y = [] + for b in batch: + y = b["y"].data + num_speakers = len(b["y"].labels) + if num_speakers > self.max_speakers_per_chunk: + # sort speakers in descending talkativeness order + indices = np.argsort(-np.sum(y, axis=0), axis=0) + # keep only the most talkative speakers + y = y[:, indices[: self.max_speakers_per_chunk]] + + # TODO: we should also sort the speaker labels in the same way + + elif num_speakers < self.max_speakers_per_chunk: + # create inactive speakers by zero padding + y = np.pad( + y, + ((0, 0), (0, self.max_speakers_per_chunk - num_speakers)), + mode="constant", + ) + + else: + # we have exactly the right number of speakers + pass + + collated_y.append(y) + + return torch.from_numpy(np.stack(collated_y)) + + def collate_meta(self, batch) -> torch.Tensor: + """Collate for metadata""" + return default_collate([b["meta"] for b in batch]) + + def collate_fn(self, batch, stage="train"): + """Collate function used for most segmentation tasks + + This function does the following: + * stack waveforms into a (batch_size, num_channels, num_samples) tensor batch["X"]) + * apply augmentation when in "train" stage + * convert targets into a (batch_size, num_frames, num_classes) tensor batch["y"] + * collate any other keys that might be present in the batch using pytorch default_collate function + + Parameters + ---------- + batch : list of dict + List of training samples. + + Returns + ------- + batch : dict + Collated batch as {"X": torch.Tensor, "y": torch.Tensor} dict. + """ + + # collate X + collated_X = self.collate_X(batch) + + # collate y + try: + collated_y = self.collate_y(batch) + except RuntimeError as e: + print(e) + print([b["y"].data for b in batch]) + + # collate metadata + collated_meta = self.collate_meta(batch) + + # apply augmentation (only in "train" stage) + self.augmentation.train(mode=(stage == "train")) + augmented = self.augmentation( + samples=collated_X, + sample_rate=self.model.hparams.sample_rate, + targets=collated_y.unsqueeze(1), + ) + + return { + "X": augmented.samples, + "y": augmented.targets.squeeze(1), + "meta": collated_meta, + } + + def segmentation_loss( self, permutated_prediction: torch.Tensor, From d2d6e14f816e0d7d65d31b23ea02167c230af7ee Mon Sep 17 00:00:00 2001 From: clement-pages Date: Mon, 19 Jun 2023 15:53:00 +0200 Subject: [PATCH 09/83] remove support for non-powerset mode --- .../speaker_diarization_and_embedding.py | 103 ++++++------------ 1 file changed, 35 insertions(+), 68 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index fc08a4d74..baf7e1246 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -878,24 +878,18 @@ def segmentation_loss( Permutation-invariant segmentation loss """ - if self.specifications[Subtasks.index("diarization")].powerset: - - # `clamp_min` is needed to set non-speech weight to 1. - class_weight = ( - torch.clamp_min(self.model.powerset.cardinality, 1.0) - if self.weigh_by_cardinality - else None - ) - seg_loss = nll_loss( - permutated_prediction, - torch.argmax(target, dim=-1), - class_weight=class_weight, - weight=weight, - ) - else: - seg_loss = binary_cross_entropy( - permutated_prediction, target.float(), weight=weight - ) + # `clamp_min` is needed to set non-speech weight to 1. + class_weight = ( + torch.clamp_min(self.model.powerset.cardinality, 1.0) + if self.weigh_by_cardinality + else None + ) + seg_loss = nll_loss( + permutated_prediction, + torch.argmax(target, dim=-1), + class_weight=class_weight, + weight=weight, + ) return seg_loss @@ -938,11 +932,10 @@ def voice_activity_detection_loss( def setup_loss_func(self): diarization_spec = self.specifications[Subtasks.index("diarization")] - if diarization_spec.powerset: - self.model.powerset = Powerset( - len(diarization_spec.classes), - diarization_spec.powerset_max_classes, - ) + self.model.powerset = Powerset( + len(diarization_spec.classes), + diarization_spec.powerset_max_classes, + ) def compute_diarization_loss(self, batch : torch.Tensor): """""" @@ -985,26 +978,18 @@ def compute_diarization_loss(self, batch : torch.Tensor): warm_up_right = round(self.warm_up[1] / self.duration * num_frames) weight[:, num_frames - warm_up_right :] = 0.0 - if self.specifications[Subtasks.index("diarization")].powerset: - - powerset = torch.nn.functional.one_hot( - torch.argmax(prediction, dim=-1), - self.model.powerset.num_powerset_classes, - ).float() - multilabel = self.model.powerset.to_multilabel(powerset) - permutated_target, _ = permutate(multilabel, target) - permutated_target_powerset = self.model.powerset.to_powerset( - permutated_target.float() - ) - seg_loss = self.segmentation_loss( - prediction, permutated_target_powerset, weight=weight - ) - - else: - permutated_prediction, _ = permutate(target, prediction) - seg_loss = self.segmentation_loss( - permutated_prediction, target, weight=weight - ) + powerset = torch.nn.functional.one_hot( + torch.argmax(prediction, dim=-1), + self.model.powerset.num_powerset_classes, + ).float() + multilabel = self.model.powerset.to_multilabel(powerset) + permutated_target, _ = permutate(multilabel, target) + permutated_target_powerset = self.model.powerset.to_powerset( + permutated_target.float() + ) + seg_loss = self.segmentation_loss( + prediction, permutated_target_powerset, weight=weight + ) self.model.log( f"{self.logging_prefix}TrainSegLoss", @@ -1022,15 +1007,9 @@ def compute_diarization_loss(self, batch : torch.Tensor): # TODO: vad_loss probably does not make sense in powerset mode # because first class (empty set of labels) does exactly this... - if self.specifications[Subtasks.index("diarization")].powerset: - vad_loss = self.voice_activity_detection_loss( - prediction, permutated_target_powerset, weight=weight - ) - - else: - vad_loss = self.voice_activity_detection_loss( - permutated_prediction, target, weight=weight - ) + vad_loss = self.voice_activity_detection_loss( + prediction, permutated_target_powerset, weight=weight + ) self.model.log( f"{self.logging_prefix}TrainVADLoss", @@ -1105,23 +1084,11 @@ def default_metric( """Returns diarization error rate and its components for diarization subtask, and equal error rate for the embedding part """ - - if self.specifications[Subtasks.index("diarization")].powerset: - return { - "DiarizationErrorRate": DiarizationErrorRate(0.5), - "DiarizationErrorRate/Confusion": SpeakerConfusionRate(0.5), - "DiarizationErrorRate/Miss": MissedDetectionRate(0.5), - "DiarizationErrorRate/FalseAlarm": FalseAlarmRate(0.5), - "EqualErrorRate": EqualErrorRate(compute_on_cpu=True, distances=False), - "BinaryAUROC": BinaryAUROC(compute_on_cpu=True), - } - return { - "DiarizationErrorRate": OptimalDiarizationErrorRate(), - "DiarizationErrorRate/Threshold": OptimalDiarizationErrorRateThreshold(), - "DiarizationErrorRate/Confusion": OptimalSpeakerConfusionRate(), - "DiarizationErrorRate/Miss": OptimalMissedDetectionRate(), - "DiarizationErrorRate/FalseAlarm": OptimalFalseAlarmRate(), + "DiarizationErrorRate": DiarizationErrorRate(0.5), + "DiarizationErrorRate/Confusion": SpeakerConfusionRate(0.5), + "DiarizationErrorRate/Miss": MissedDetectionRate(0.5), + "DiarizationErrorRate/FalseAlarm": FalseAlarmRate(0.5), "EqualErrorRate": EqualErrorRate(compute_on_cpu=True, distances=False), "BinaryAUROC": BinaryAUROC(compute_on_cpu=True), } From e58943bdafbba56b157529021c8f1dff958f1b3f Mon Sep 17 00:00:00 2001 From: clement-pages Date: Mon, 19 Jun 2023 16:04:22 +0200 Subject: [PATCH 10/83] remove computing of vad loss as computing this loss probably does not make sense in powerset mode because first class (empty set of labels) does exactly this --- .../speaker_diarization_and_embedding.py | 60 +------------------ 1 file changed, 2 insertions(+), 58 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index baf7e1246..ca6fedc44 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -893,43 +893,6 @@ def segmentation_loss( return seg_loss - def voice_activity_detection_loss( - self, - permutated_prediction: torch.Tensor, - target: torch.Tensor, - weight: torch.Tensor = None, - ) -> torch.Tensor: - """Voice activity detection loss - - Parameters - ---------- - permutated_prediction : (batch_size, num_frames, num_classes) torch.Tensor - Speaker activity predictions. - target : (batch_size, num_frames, num_speakers) torch.Tensor - Speaker activity. - weight : (batch_size, num_frames, 1) torch.Tensor, optional - Frames weight. - - Returns - ------- - vad_loss : torch.Tensor - Voice activity detection loss. - """ - - vad_prediction, _ = torch.max(permutated_prediction, dim=2, keepdim=True) - # (batch_size, num_frames, 1) - - vad_target, _ = torch.max(target.float(), dim=2, keepdim=False) - # (batch_size, num_frames) - - if self.vad_loss == "bce": - loss = binary_cross_entropy(vad_prediction, vad_target, weight=weight) - - elif self.vad_loss == "mse": - loss = mse_loss(vad_prediction, vad_target, weight=weight) - - return loss - def setup_loss_func(self): diarization_spec = self.specifications[Subtasks.index("diarization")] self.model.powerset = Powerset( @@ -944,6 +907,7 @@ def compute_diarization_loss(self, batch : torch.Tensor): num_speakers: torch.Tensor = torch.sum(torch.any(target, dim=1), dim=1) keep : torch.Tensor = num_speakers <= self.max_speakers_per_chunk target = target[keep] + # TODO using variable `waveform` before assignment waveform = waveform[keep] # log effective batch size @@ -1000,27 +964,7 @@ def compute_diarization_loss(self, batch : torch.Tensor): logger=True, ) - if self.vad_loss is None: - vad_loss = 0.0 - - else: - - # TODO: vad_loss probably does not make sense in powerset mode - # because first class (empty set of labels) does exactly this... - vad_loss = self.voice_activity_detection_loss( - prediction, permutated_target_powerset, weight=weight - ) - - self.model.log( - f"{self.logging_prefix}TrainVADLoss", - vad_loss, - on_step=False, - on_epoch=True, - prog_bar=False, - logger=True, - ) - - loss = seg_loss + vad_loss + loss = seg_loss # skip batch if something went wrong for some reason if torch.isnan(loss): return None From bc989cdadcb511859b38fb42a8aed9ff7713a7f5 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Mon, 19 Jun 2023 16:20:31 +0200 Subject: [PATCH 11/83] remove unused imports --- .../tasks/joint_task/speaker_diarization_and_embedding.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index ca6fedc44..ebeb7dc4c 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -36,7 +36,7 @@ from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature from pyannote.audio.core.task import Problem, Resolution, Specifications, Task -from pyannote.audio.utils.loss import binary_cross_entropy, mse_loss, nll_loss +from pyannote.audio.utils.loss import nll_loss from pyannote.audio.utils.permutation import permutate from pyannote.audio.utils.powerset import Powerset from pyannote.audio.utils.random import create_rng_for_worker @@ -47,11 +47,6 @@ DiarizationErrorRate, FalseAlarmRate, MissedDetectionRate, - OptimalDiarizationErrorRate, - OptimalDiarizationErrorRateThreshold, - OptimalFalseAlarmRate, - OptimalMissedDetectionRate, - OptimalSpeakerConfusionRate, SpeakerConfusionRate, ) From b4d0a7803b98e34ca7635f3357a4fee9be35b490 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Mon, 19 Jun 2023 17:55:24 +0200 Subject: [PATCH 12/83] fix probabilities do not sum to 1 error --- .../audio/tasks/joint_task/speaker_diarization_and_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index ebeb7dc4c..a77e402a4 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -679,9 +679,9 @@ def train__iter__helper(self, rng : random.Random, **filters): embedding_files_ids = file_ids[np.in1d(file_ids, self.embedding_database_files)] annotated_duration = self.annotated_duration[file_ids] + annotated_duration[embedding_files_ids] = 0 prob_annotated_duration = annotated_duration / np.sum(annotated_duration) # set probability to sample a file from embedding database to 0 - prob_annotated_duration[embedding_files_ids] = 0 duration = self.duration From 78718b1639fa554ac3009e7af3ec8a7c4f8fc7c9 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 20 Jun 2023 09:27:23 +0200 Subject: [PATCH 13/83] attempt to fix file duration error --- .../tasks/joint_task/speaker_diarization_and_embedding.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index a77e402a4..f941c16b9 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -520,11 +520,10 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float, subtas label_scope = Scopes[self.metadata[file_id]["scope"]] label_scope_key = f"{label_scope}_label_idx" - # chunk = Segment(start_time, start_time + duration) sample = dict() - sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) + sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration, mode="pad") # gather all annotations of current file annotations = self.annotations[self.annotations["file_id"] == file_id] From dfdd8f38f307239d69f5b2434189f684f3cc123e Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 20 Jun 2023 09:51:38 +0200 Subject: [PATCH 14/83] attempt to fix negative `start_time` in embedding part --- .../audio/tasks/joint_task/speaker_diarization_and_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index f941c16b9..cb16bba1f 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -648,7 +648,7 @@ def draw_embedding_chunk(self, klass : Text, segment = np.random.choice(class_segments, p=prob_segments) # sample chunk start time in order to intersect it with the sampled segment - start_time = np.random.uniform(segment["start"] - duration / 2, segment["start"]) + start_time = np.random.uniform(max(segment["start"] - duration / 2, 0), segment["start"]) return (segment["file_id"], start_time) From 18883600d6e00f2e9bd0873f9ff1cc5e7f81b733 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 20 Jun 2023 16:39:08 +0200 Subject: [PATCH 15/83] add end-to-end diarization and embedding model --- .../models/joint/end_to_end_diarization.py | 225 ++++++++++++++++++ 1 file changed, 225 insertions(+) create mode 100644 pyannote/audio/models/joint/end_to_end_diarization.py diff --git a/pyannote/audio/models/joint/end_to_end_diarization.py b/pyannote/audio/models/joint/end_to_end_diarization.py new file mode 100644 index 000000000..fc36d654b --- /dev/null +++ b/pyannote/audio/models/joint/end_to_end_diarization.py @@ -0,0 +1,225 @@ +# MIT License +# +# Copyright (c) 2020 CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from typing import Optional +from warnings import warn +from einops import rearrange + +import torch +from torch import nn +import torch.nn.functional as F + +from pyannote.audio.core.model import Model +from pyannote.audio.core.task import Task +from pyannote.audio.models.blocks.sincnet import SincNet +from pyannote.audio.models.blocks.pooling import StatsPool +from pyannote.audio.utils.params import merge_dict +from pyannote.core.utils.generators import pairwise + +class SpeakerEndToEndDiarization(Model): + """Speaker End-to-End Diarization and Embedding model + SINCNET -- TDNN .. TDNN -- TDNN ..TDNN -- StatsPool -- Linear -- Classifier + \ LSTM ... LSTM -- FeedForward -- Classifier + """ + SINCNET_DEFAULTS = {"stride": 10} + LSTM_DEFAULTS = { + "hidden_size": 128, + "num_layers": 2, + "bidirectional": True, + "monolithic": True, + "dropout": 0.0, + "batch_first": True, + } + LINEAR_DEFAULTS = {"hidden_size": 128, "num_layers": 2} + + def __init__( + self, + sincnet: dict = None, + lstm: dict= None, + linear: dict = None, + sample_rate: int = 16000, + num_channels: int = 1, + num_features: int = 60, + embedding_dim: int = 512, + separation_idx: int = 2, + task: Optional[Task] = None, + ): + super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) + + if num_features != 60: + warn("For now, the model only support a number of features of 60. Set it to 60") + num_features = 60 + self.num_features = num_features + self.separation_idx = separation_idx + self.save_hyperparameters("num_features", "embedding_dim", "separation_idx") + + + # sincnet module + sincnet = merge_dict(self.SINCNET_DEFAULTS, sincnet) + sincnet["sample_rate"] = sample_rate + self.sincnet =SincNet(**sincnet) + self.save_hyperparameters("sincnet") + + # tdnn modules + self.tdnn_blocks = nn.ModuleList() + in_channel = num_features + out_channels = [512, 512, 512, 512, 1500] + kernel_sizes = [5, 3, 3, 1, 1] + dilations = [1, 2, 3, 1, 1] + + for out_channel, kernel_size, dilation in zip( + out_channels, kernel_sizes, dilations + ): + self.tdnn_blocks.extend( + [ + nn.Sequential( + nn.Conv1d( + in_channels=in_channel, + out_channels=out_channel, + kernel_size=kernel_size, + dilation=dilation, + ), + nn.LeakyReLU(), + nn.BatchNorm1d(out_channel), + ), + ] + ) + in_channel = out_channel + + # lstm modules: + lstm = merge_dict(self.LSTM_DEFAULTS, lstm) + self.save_hyperparameters("lstm") + monolithic = lstm["monolithic"] + if monolithic: + multi_layer_lstm = dict(lstm) + del multi_layer_lstm["monolithic"] + self.lstm = nn.LSTM(out_channels[separation_idx], **multi_layer_lstm) + else: + num_layers = lstm["num_layers"] + if num_layers > 1: + self.dropout = nn.Dropout(p=lstm["dropout"]) + + one_layer_lstm = dict(lstm) + del one_layer_lstm["monolithic"] + one_layer_lstm["num_layers"] = 1 + one_layer_lstm["dropout"] = 0.0 + + self.lstm = nn.ModuleList( + [ + nn.LSTM( + out_channels[separation_idx] + if i == 0 + else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1), + **one_layer_lstm + ) + for i in range(num_layers) + ] + ) + + # linear module for the diarization part: + linear = merge_dict(self.LINEAR_DEFAULTS, linear) + self.save_hyperparameters("linear") + if linear["num_layers"] < 1: + return + + lstm_out_features: int = self.hparams.lstm["hidden_size"] * ( + 2 if self.hparams.lstm["bidirectional"] else 1 + ) + self.linear = nn.ModuleList( + [ + nn.Linear(in_features, out_features) + for in_features, out_features in pairwise( + [ + lstm_out_features, + ] + + [self.hparams.linear["hidden_size"]] + * self.hparams.linear["num_layers"] + ) + ] + ) + + # stats pooling module for the embedding part: + self.stats_pool = StatsPool() + # linear module for the embedding part: + self.embedding = nn.Linear(in_channel * 2, embedding_dim) + + + + def build(self): + if self.hparams.linear["num_layers"] > 0: + in_features = self.hparams.linear["hidden_size"] + else: + in_features = self.hparams.lstm["hidden_size"] * ( + 2 if self.hparams.lstm["bidirectional"] else 1 + ) + + out_features = self.specifications.num_powerset_classes + + self.classifier = nn.Linear(in_features, out_features) + self.activation = self.default_activation() + + + def forward(self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + + Parameters + ---------- + waveforms : torch.Tensor + Batch of waveforms with shape (batch, channel, sample) + weights : torch.Tensor, optional + Batch of weights wiht shape (batch, frame) + """ + common_outputs = self.sincnet(waveforms) + # (batch, features, frames) + # common part to diarization and embedding: + tdnn_idx = 0 + while tdnn_idx <= self.separation_idx: + common_outputs = self.tdnn_blocks[tdnn_idx](common_outputs) + tdnn_idx = tdnn_idx + 1 + # diarization part: + if self.hparams.lstm["monolithic"]: + diarization_outputs, _ = self.lstm( + rearrange(common_outputs, "batch feature frame -> batch frame feature") + ) + else: + diarization_outputs = rearrange(common_outputs, "batch feature frame -> batch frame feature") + for i, lstm in enumerate(self.lstm): + diarization_outputs, _ = lstm(diarization_outputs) + if i + 1 < self.hparams.lstm["num_layers"]: + diarization_outputs = self.linear() + + if self.hparams.linear["num_layers"] > 0: + for linear in self.linear: + diarization_outputs = F.leaky_relu(linear(diarization_outputs)) + diarization_outputs = self.classifier(diarization_outputs) + diarization_outputs = self.activation(diarization_outputs) + + # embedding part: + embedding_outputs = torch.clone(common_outputs) + for tdnn_block in self.tdnn_blocks[tdnn_idx:]: + embedding_outputs = tdnn_block(embedding_outputs) + + # TODO : reinject diarization outputs into the pooling layers: + embedding_outputs = self.stats_pool(embedding_outputs) + embedding_outputs = self.embedding(embedding_outputs) + + return (diarization_outputs, embedding_outputs) From 6216d1f7cc70ea5caf1a27930c6e644d6ce5f95a Mon Sep 17 00:00:00 2001 From: clement-pages Date: Wed, 21 Jun 2023 22:03:50 +0200 Subject: [PATCH 16/83] update end-to-end model --- pyannote/audio/models/joint/__init__.py | 25 ++++++++++++ .../models/joint/end_to_end_diarization.py | 40 ++++++++++++++----- .../speaker_diarization_and_embedding.py | 11 ++--- 3 files changed, 58 insertions(+), 18 deletions(-) create mode 100644 pyannote/audio/models/joint/__init__.py diff --git a/pyannote/audio/models/joint/__init__.py b/pyannote/audio/models/joint/__init__.py new file mode 100644 index 000000000..3c6230fdb --- /dev/null +++ b/pyannote/audio/models/joint/__init__.py @@ -0,0 +1,25 @@ +# MIT License +# +# Copyright (c) 2020 CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from .end_to_end_diarization import SpeakerEndToEndDiarization + +__all__ = ["SpeakerEndToEndDiarization"] diff --git a/pyannote/audio/models/joint/end_to_end_diarization.py b/pyannote/audio/models/joint/end_to_end_diarization.py index fc36d654b..e97e449e4 100644 --- a/pyannote/audio/models/joint/end_to_end_diarization.py +++ b/pyannote/audio/models/joint/end_to_end_diarization.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020 CNRS +# Copyright (c) 2023 CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -20,21 +20,29 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Optional +from typing import Literal, Optional from warnings import warn from einops import rearrange import torch from torch import nn import torch.nn.functional as F +from pytorch_metric_learning.losses import ArcFaceLoss from pyannote.audio.core.model import Model from pyannote.audio.core.task import Task from pyannote.audio.models.blocks.sincnet import SincNet from pyannote.audio.models.blocks.pooling import StatsPool from pyannote.audio.utils.params import merge_dict +from pyannote.audio.utils.powerset import Powerset from pyannote.core.utils.generators import pairwise + +# TODO deplace these two lines into uitls/multi_task +Subtask = Literal["diarization", "embedding"] +Subtasks = list(Subtask.__args__) + + class SpeakerEndToEndDiarization(Model): """Speaker End-to-End Diarization and Embedding model SINCNET -- TDNN .. TDNN -- TDNN ..TDNN -- StatsPool -- Linear -- Classifier @@ -172,10 +180,21 @@ def build(self): 2 if self.hparams.lstm["bidirectional"] else 1 ) - out_features = self.specifications.num_powerset_classes - + diarization_spec = self.specifications[Subtasks.index("diarization")] + out_features = diarization_spec.num_powerset_classes self.classifier = nn.Linear(in_features, out_features) - self.activation = self.default_activation() + + self.powerset = Powerset( + len(diarization_spec.classes), + diarization_spec.powerset_max_classes, + ) + + self.arc_face_loss = ArcFaceLoss( + len(self.specifications[Subtasks.index("embedding")].classes), + self.hparams["embedding_dim"], + margin=self.task.margin, + scale=self.task.scale, + ) def forward(self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor: @@ -211,15 +230,14 @@ def forward(self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = Non for linear in self.linear: diarization_outputs = F.leaky_relu(linear(diarization_outputs)) diarization_outputs = self.classifier(diarization_outputs) - diarization_outputs = self.activation(diarization_outputs) - + diarization_outputs = F.log_softmax(diarization_outputs, dim=-1) + weights = self.powerset(diarization_outputs).transpose(1, 2) + # embedding part: embedding_outputs = torch.clone(common_outputs) for tdnn_block in self.tdnn_blocks[tdnn_idx:]: embedding_outputs = tdnn_block(embedding_outputs) - - # TODO : reinject diarization outputs into the pooling layers: - embedding_outputs = self.stats_pool(embedding_outputs) + embedding_outputs = self.stats_pool(embedding_outputs, weights=weights) embedding_outputs = self.embedding(embedding_outputs) - + return (diarization_outputs, embedding_outputs) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index cb16bba1f..92ee12548 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -88,6 +88,8 @@ def __init__( database_ratio : float = 0.5, num_workers: int = None, pin_memory: bool = False, + margin : float = 28.6, + scale: float = 64.0, augmentation: BaseWaveformTransform = None ) -> None: super().__init__( @@ -102,6 +104,8 @@ def __init__( self.max_speakers_per_chunk = max_speakers_per_chunk self.max_speakers_per_frame = max_speakers_per_frame self.database_ratio = database_ratio + self.margin = margin + self.scale = scale # keep track of the use of database available in the meta protocol @@ -887,13 +891,6 @@ def segmentation_loss( return seg_loss - def setup_loss_func(self): - diarization_spec = self.specifications[Subtasks.index("diarization")] - self.model.powerset = Powerset( - len(diarization_spec.classes), - diarization_spec.powerset_max_classes, - ) - def compute_diarization_loss(self, batch : torch.Tensor): """""" X, y = batch["X"], batch["y"] From b42cc33011f01772fbde17cbeb54d4bbce360644 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Wed, 21 Jun 2023 22:15:05 +0200 Subject: [PATCH 17/83] clean multi-task source code --- .../speaker_diarization_and_embedding.py | 32 ++++++++----------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 92ee12548..cf230773a 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020- CNRS +# Copyright (c) 2023- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -23,10 +23,10 @@ from collections import defaultdict import itertools import random -import numpy as np -import torch from typing import Literal, List, Text, Union, Sequence, Dict import warnings +import numpy as np +import torch from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchaudio.backend.common import AudioMetaData @@ -34,11 +34,10 @@ from torchmetrics.classification import BinaryAUROC from torch.utils.data._utils.collate import default_collate -from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature +from pyannote.core import Segment, SlidingWindowFeature from pyannote.audio.core.task import Problem, Resolution, Specifications, Task from pyannote.audio.utils.loss import nll_loss from pyannote.audio.utils.permutation import permutate -from pyannote.audio.utils.powerset import Powerset from pyannote.audio.utils.random import create_rng_for_worker from pyannote.database.protocol import SegmentationProtocol, SpeakerDiarizationProtocol from pyannote.database.protocol.protocol import Scope, Subset @@ -673,7 +672,7 @@ def train__iter__helper(self, rng : random.Random, **filters): Training chunks """ - # indices of trainijng files that matches domain filters + # indices of training files that matches domain filters training = self.metadata["subset"] == Subsets.index("train") for key, value in filters.items(): training &= self.metadata[key] == value @@ -682,14 +681,16 @@ def train__iter__helper(self, rng : random.Random, **filters): embedding_files_ids = file_ids[np.in1d(file_ids, self.embedding_database_files)] annotated_duration = self.annotated_duration[file_ids] + # set duration of files for the embedding part to zero, in order to not + # drawn them for diarization part annotated_duration[embedding_files_ids] = 0 prob_annotated_duration = annotated_duration / np.sum(annotated_duration) - # set probability to sample a file from embedding database to 0 duration = self.duration - # make a copy of the original classes list, in order to not modify it when shuffling + # make a copy of the original class list, so as not to modify it during shuffling embedding_classes = self.specifications[Subtasks.index("embedding")].classes + # use original order for the first run of the shuffled classes list: shuffled_embedding_classes = list(embedding_classes) embedding_class_idx = 0 @@ -698,7 +699,9 @@ def train__iter__helper(self, rng : random.Random, **filters): # between these two tasks if np.random.uniform() < self.database_ratio: subtask = Subtasks.index("diarization") - file_id, start_time = self.draw_diarization_chunk(file_ids, prob_annotated_duration, rng, duration) + file_id, start_time = self.draw_diarization_chunk(file_ids, prob_annotated_duration, + rng, + duration) else: subtask = Subtasks.index("embedding") # shuffle embedding classes list and go through this shuffled list @@ -708,9 +711,7 @@ def train__iter__helper(self, rng : random.Random, **filters): embedding_class_idx = 0 klass = shuffled_embedding_classes[embedding_class_idx] embedding_class_idx += 1 - file_id, start_time = self.draw_embedding_chunk(klass, - classes=embedding_classes, - duration=duration) + file_id, start_time = self.draw_embedding_chunk(klass, embedding_classes, duration) sample = self.prepare_chunk(file_id, start_time, duration, subtask) yield sample @@ -827,13 +828,8 @@ def collate_fn(self, batch, stage="train"): # collate X collated_X = self.collate_X(batch) - # collate y - try: - collated_y = self.collate_y(batch) - except RuntimeError as e: - print(e) - print([b["y"].data for b in batch]) + collated_y = self.collate_y(batch) # collate metadata collated_meta = self.collate_meta(batch) From 3d295dde9adba33ef32442bf26b106a0a70a96ef Mon Sep 17 00:00:00 2001 From: clement-pages Date: Wed, 21 Jun 2023 22:25:10 +0200 Subject: [PATCH 18/83] remove support for `SegmentationProtocol` in the multi-tasks --- .../speaker_diarization_and_embedding.py | 53 +------------------ 1 file changed, 1 insertion(+), 52 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index cf230773a..fa8897a07 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -39,7 +39,7 @@ from pyannote.audio.utils.loss import nll_loss from pyannote.audio.utils.permutation import permutate from pyannote.audio.utils.random import create_rng_for_worker -from pyannote.database.protocol import SegmentationProtocol, SpeakerDiarizationProtocol +from pyannote.database.protocol import SpeakerDiarizationProtocol from pyannote.database.protocol.protocol import Scope, Subset from pyannote.audio.torchmetrics.classification import EqualErrorRate from pyannote.audio.torchmetrics import ( @@ -158,9 +158,6 @@ def setup(self, stage="fit"): if isinstance(self.protocol, SpeakerDiarizationProtocol): metadata_unique_values["scope"] = Scopes - elif isinstance(self.protocol, SegmentationProtocol): - classes = getattr(self, "classes", list()) - # make sure classes attribute exists (and set to None if it did not exist) self.classes = getattr(self, "classes", None) if self.classes is None: @@ -209,42 +206,6 @@ def setup(self, stage="fit"): elif file["scope"] in ["database", "file"]: self.diarization_database_files.append(file_id) - # keep track of list of classes for regular segmentation protocols - # Different files may be annotated using a different set of classes - # (e.g. one database for speech/music/noise, and another one for male/female/child) - if isinstance(self.protocol, SegmentationProtocol): - - if "classes" in file: - local_classes = file["classes"] - else: - local_classes = file["annotation"].labels() - - # if task was not initialized with a fixed list of classes, - # we build it as the union of all classes found in files - if self.classes is None: - for klass in local_classes: - if klass not in classes: - classes.append(klass) - annotated_classes.append( - [classes.index(klass) for klass in local_classes] - ) - - # if task was initialized with a fixed list of classes, - # we make sure that all files use a subset of these classes - # if they don't, we issue a warning and ignore the extra classes - else: - extra_classes = set(local_classes) - set(self.classes) - if extra_classes: - warnings.warn( - f"Ignoring extra classes ({', '.join(extra_classes)}) found for file {file['uri']} ({file['database']}). " - ) - annotated_classes.append( - [ - self.classes.index(klass) - for klass in set(local_classes) & set(self.classes) - ] - ) - remaining_metadata_keys = set(file) - set( [ "uri", @@ -415,18 +376,6 @@ def setup(self, stage="fit"): dtype = [("file_id", "i"), ("duration", "f"), ("start", "f"), ("end", "f")] self.annotated_regions = np.array(annotated_regions, dtype=dtype) - # convert annotated_classes (which is a list of list of classes, one list of classes per file) - # into a single (num_files x num_classes) numpy array: - # * True indicates that this particular class was annotated for this particular file (though it may not be active in this file) - # * False indicates that this particular class was not even annotated (i.e. its absence does not imply that it is not active in this file) - if isinstance(self.protocol, SegmentationProtocol) and self.classes is None: - self.classes = classes - self.annotated_classes = np.zeros( - (len(annotated_classes), len(self.classes)), dtype=np.bool_ - ) - for file_id, classes in enumerate(annotated_classes): - self.annotated_classes[file_id, classes] = True - # turn list of annotations into a single numpy array dtype = [ ("file_id", "i"), From 3363be66e76e0487af1983c2fc5f380042eac0a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Thu, 22 Jun 2023 17:32:59 +0200 Subject: [PATCH 19/83] improve(test): use pyannote.database.registry (#1413) --- .github/workflows/test.yml | 1 - README.md | 5 ++--- tests/conftest.py | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 34 insertions(+), 4 deletions(-) create mode 100644 tests/conftest.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b266179eb..df1182cf3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,5 +29,4 @@ jobs: pip install -e .[dev,testing] - name: Test with pytest run: | - export PYANNOTE_DATABASE_CONFIG=$GITHUB_WORKSPACE/tests/data/database.yml pytest diff --git a/README.md b/README.md index c3f9a8dcc..3bf1b2c8a 100644 --- a/README.md +++ b/README.md @@ -126,9 +126,8 @@ pip install -e .[dev,testing] pre-commit install ``` -Tests rely on a set of debugging files available in [`test/data`](test/data) directory. -Set `PYANNOTE_DATABASE_CONFIG` environment variable to `test/data/database.yml` before running tests: +## Test ```bash -PYANNOTE_DATABASE_CONFIG=tests/data/database.yml pytest +pytest ``` diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..fe2a00e12 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,32 @@ +# MIT License +# +# Copyright (c) 2020- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +def pytest_sessionstart(session): + """ + Called after the Session object has been created and + before performing collection and entering the run test loop. + """ + + from pyannote.database import registry + + registry.load_database("tests/data/database.yml") From 99a7762822a0732d4c187cb9cc98665ad7de27b1 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Fri, 23 Jun 2023 08:42:44 +0200 Subject: [PATCH 20/83] Set `alpha` coefficient as attribute --- .../tasks/joint_task/speaker_diarization_and_embedding.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index fa8897a07..5ceea7a8f 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -89,6 +89,7 @@ def __init__( pin_memory: bool = False, margin : float = 28.6, scale: float = 64.0, + alpha: float = 0.5, augmentation: BaseWaveformTransform = None ) -> None: super().__init__( @@ -105,6 +106,7 @@ def __init__( self.database_ratio = database_ratio self.margin = margin self.scale = scale + self.alpha = alpha # keep track of the use of database available in the meta protocol @@ -949,7 +951,7 @@ def training_step(self, batch, batch_idx: int): {"loss": loss} """ - alpha = 0.5 + alpha = self.alpha if batch["task"] == "diarization": # compute diarization loss diarization_loss = self.compute_diarization_loss(batch=batch) From f2a4e34a939bb9b7b01aac06f0cbdcf7ab40347d Mon Sep 17 00:00:00 2001 From: clement-pages Date: Fri, 23 Jun 2023 09:10:12 +0200 Subject: [PATCH 21/83] remove `diarization_database_files` attribute as this instance attribute was not used --- .../speaker_diarization_and_embedding.py | 20 ++++++------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 5ceea7a8f..4978c9d73 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -113,7 +113,6 @@ def __init__( # * embedding databases are those with global speaker label scope # * diarization databases are those with file or database speaker label scope self.embedding_database_files = [] - self.diarization_database_files = [] def get_file(self, file_id): @@ -201,12 +200,9 @@ def setup(self, stage="fit"): # keep track of speaker label scope (file, database, or global) for speaker diarization protocols if isinstance(self.protocol, SpeakerDiarizationProtocol): metadatum["scope"] = Scopes.index(file["scope"]) - # add the file to the embedding or diarization list according to the file database speaker - # labels scope + # keep track of files where speaker label scope is global for embedding subtask if file["scope"] == 'global': self.embedding_database_files.append(file_id) - elif file["scope"] in ["database", "file"]: - self.diarization_database_files.append(file_id) remaining_metadata_keys = set(file) - set( [ @@ -566,18 +562,14 @@ def draw_diarization_chunk(self, file_ids : np.ndarray, return (file_id, start_time) - def draw_embedding_chunk(self, klass : Text, - classes : List[Text], + def draw_embedding_chunk(self, class_id : int, duration : float) -> tuple: """Sample one chunk for the embedding task Parameters ---------- - klass: Text - current class of speakers from which to draw a sample - classes: List[Text] - list of all the global speaker labels, in the same order than the list - defined in the task specification + class_id : int + class ID in the task speficiations duration: float duration of the chunk to draw @@ -590,7 +582,6 @@ def draw_embedding_chunk(self, klass : Text, start time of the sampled chunk """ # get index of the current class in the order of original class list - class_id = classes.index(klass) # get segments for current class class_segments_idx = self.annotations["global_label_idx"] == class_id class_segments = self.annotations[class_segments_idx] @@ -661,8 +652,9 @@ def train__iter__helper(self, rng : random.Random, **filters): rng.shuffle(shuffled_embedding_classes) embedding_class_idx = 0 klass = shuffled_embedding_classes[embedding_class_idx] + class_id = embedding_classes.index(klass) embedding_class_idx += 1 - file_id, start_time = self.draw_embedding_chunk(klass, embedding_classes, duration) + file_id, start_time = self.draw_embedding_chunk(class_id, duration) sample = self.prepare_chunk(file_id, start_time, duration, subtask) yield sample From 017c9108cf821be263b57a4e7536576f8edd2077 Mon Sep 17 00:00:00 2001 From: Dmitry Mukhutdinov Date: Fri, 23 Jun 2023 22:11:54 +0800 Subject: [PATCH 22/83] feat(pipeline): add `return_embeddings` option to `SpeakerDiarization` pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Hervé BREDIN --- CHANGELOG.md | 4 +- pyannote/audio/pipelines/clustering.py | 61 ++++++++------ .../audio/pipelines/speaker_diarization.py | 81 +++++++++++++++---- pyannote/audio/pipelines/utils/diarization.py | 28 +++++-- 4 files changed, 128 insertions(+), 46 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bc51c0e9c..6e7d220fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,13 +30,13 @@ If, for some weird reason, you wrote some custom code based on that, you should instead rely on `Model.example_output`. - ### Features and improvements - feat(task): add support for multi-task models - feat(pipeline): send pipeline to device with `pipeline.to(device)` - feat(pipeline): make `segmentation_batch_size` and `embedding_batch_size` mutable in `SpeakerDiarization` pipeline (they now default to `1`) - feat(task): add [powerset](https://arxiv.org/PLACEHOLDER) support to `SpeakerDiarization` task + - feat(pipeline): add `return_embeddings` option to `SpeakerDiarization` pipeline - feat(pipeline): add progress hook to pipelines - feat(pipeline): check version compatibility at load time - feat(task): add support for label scope in speaker diarization task @@ -88,7 +88,7 @@ - last release before complete rewriting -## Version 1.0.1 (2018--07-19) +## Version 1.0.1 (2018-07-19) - fix: fix regression in Precomputed.__call__ (#110, #105) diff --git a/pyannote/audio/pipelines/clustering.py b/pyannote/audio/pipelines/clustering.py index f282ea39c..3c2786232 100644 --- a/pyannote/audio/pipelines/clustering.py +++ b/pyannote/audio/pipelines/clustering.py @@ -48,7 +48,6 @@ def __init__( max_num_embeddings: int = 1000, constrained_assignment: bool = False, ): - super().__init__() self.metric = metric self.max_num_embeddings = max_num_embeddings @@ -61,7 +60,6 @@ def set_num_clusters( min_clusters: int = None, max_clusters: int = None, ): - min_clusters = num_clusters or min_clusters or 1 min_clusters = max(1, min(num_embeddings, min_clusters)) max_clusters = num_clusters or max_clusters or num_embeddings @@ -113,7 +111,6 @@ def filter_embeddings( return embeddings[chunk_idx, speaker_idx], chunk_idx, speaker_idx def constrained_argmax(self, soft_clusters: np.ndarray) -> np.ndarray: - soft_clusters = np.nan_to_num(soft_clusters, nan=np.nanmin(soft_clusters)) num_chunks, num_speakers, num_clusters = soft_clusters.shape # num_chunks, num_speakers, num_clusters @@ -156,6 +153,8 @@ def assign_embeddings( ------- soft_clusters : (num_chunks, num_speakers, num_clusters)-shaped array hard_clusters : (num_chunks, num_speakers)-shaped array + centroids : (num_clusters, dimension)-shaped array + Clusters centroids """ # TODO: option to add a new (dummy) cluster in case num_clusters < max(frame_speaker_count) @@ -194,7 +193,7 @@ def assign_embeddings( # TODO: add a flag to revert argmax for trainign subset # hard_clusters[train_chunk_idx, train_speaker_idx] = train_clusters - return hard_clusters, soft_clusters + return hard_clusters, soft_clusters, centroids def __call__( self, @@ -230,6 +229,8 @@ def __call__( soft_clusters : (num_chunks, num_speakers, num_clusters) array Soft cluster assignment (the higher soft_clusters[c, s, k], the most likely the sth speaker of cth chunk belongs to kth cluster) + centroids : (num_clusters, dimension) array + Centroid vectors of each cluster """ train_embeddings, train_chunk_idx, train_speaker_idx = self.filter_embeddings( @@ -250,7 +251,9 @@ def __call__( num_chunks, num_speakers, _ = embeddings.shape hard_clusters = np.zeros((num_chunks, num_speakers), dtype=np.int8) soft_clusters = np.ones((num_chunks, num_speakers, 1)) - return hard_clusters, soft_clusters + centroids = np.mean(train_embeddings, axis=0, keepdims=True) + + return hard_clusters, soft_clusters, centroids train_clusters = self.cluster( train_embeddings, @@ -259,7 +262,7 @@ def __call__( num_clusters=num_clusters, ) - hard_clusters, soft_clusters = self.assign_embeddings( + hard_clusters, soft_clusters, centroids = self.assign_embeddings( embeddings, train_chunk_idx, train_speaker_idx, @@ -267,7 +270,7 @@ def __call__( constrained=self.constrained_assignment, ) - return hard_clusters, soft_clusters + return hard_clusters, soft_clusters, centroids class AgglomerativeClustering(BaseClustering): @@ -286,19 +289,6 @@ class AgglomerativeClustering(BaseClustering): Clustering threshold. min_cluster_size : int in range [1, 20] Minimum cluster size - - Usage - ----- - >>> clustering = AgglomerativeClustering(metric="cosine") - >>> clustering.instantiate({"method": "average", - ... "threshold": 1.0, - ... "min_cluster_size": 1}) - >>> clusters, _ = clustering(embeddings, # shape - ... num_clusters=None, - ... min_clusters=None, - ... max_clusters=None) - where `embeddings` is a np.ndarray with shape (num_embeddings, embedding_dimension) - and `clusters` is a np.ndarray with shape (num_embeddings, ) """ def __init__( @@ -307,7 +297,6 @@ def __init__( max_num_embeddings: int = np.inf, constrained_assignment: bool = False, ): - super().__init__( metric=metric, max_num_embeddings=max_num_embeddings, @@ -397,7 +386,6 @@ def cluster( num_clusters = max_clusters if num_clusters is not None: - # switch stopping criterion from "inter-cluster distance" stopping to "iteration index" _dendrogram = np.copy(dendrogram) _dendrogram[:, 2] = np.arange(num_embeddings - 1) @@ -409,7 +397,6 @@ def cluster( # from the "optimal" threshold for iteration in np.argsort(np.abs(dendrogram[:, 2] - self.threshold)): - # only consider iterations that might have resulted # in changing the number of (large) clusters new_cluster_size = _dendrogram[iteration, 3] @@ -481,6 +468,7 @@ class OracleClustering(BaseClustering): def __call__( self, + embeddings: np.ndarray = None, segmentations: SlidingWindowFeature = None, file: AudioFile = None, frames: SlidingWindow = None, @@ -490,6 +478,9 @@ def __call__( Parameters ---------- + embeddings : (num_chunks, num_speakers, dimension) array, optional + Sequence of embeddings. When provided, compute speaker centroids + based on these embeddings. segmentations : (num_chunks, num_frames, num_speakers) array Binary segmentations. file : AudioFile @@ -503,6 +494,8 @@ def __call__( soft_clusters : (num_chunks, num_speakers, num_clusters) array Soft cluster assignment (the higher soft_clusters[c, s, k], the most likely the sth speaker of cth chunk belongs to kth cluster) + centroids : (num_clusters, dimension), optional + Clusters centroids if `embeddings` is provided, None otherwise. """ num_chunks, num_frames, num_speakers = segmentations.data.shape @@ -532,7 +525,27 @@ def __call__( hard_clusters[c, i] = j soft_clusters[c, i, j] = 1.0 - return hard_clusters, soft_clusters + if embeddings is None: + return hard_clusters, soft_clusters, None + + ( + train_embeddings, + train_chunk_idx, + train_speaker_idx, + ) = self.filter_embeddings( + embeddings, + segmentations=segmentations, + ) + + train_clusters = hard_clusters[train_chunk_idx, train_speaker_idx] + centroids = np.vstack( + [ + np.mean(train_embeddings[train_clusters == k], axis=0) + for k in range(num_clusters) + ] + ) + + return hard_clusters, soft_clusters, centroids class Clustering(Enum): diff --git a/pyannote/audio/pipelines/speaker_diarization.py b/pyannote/audio/pipelines/speaker_diarization.py index 8cf30f3b9..f59551176 100644 --- a/pyannote/audio/pipelines/speaker_diarization.py +++ b/pyannote/audio/pipelines/speaker_diarization.py @@ -89,11 +89,20 @@ class SpeakerDiarization(SpeakerDiarizationMixin, Pipeline): Usage ----- - >>> pipeline = SpeakerDiarization() + # perform (unconstrained) diarization >>> diarization = pipeline("/path/to/audio.wav") + + # perform diarization, targetting exactly 4 speakers >>> diarization = pipeline("/path/to/audio.wav", num_speakers=4) + + # perform diarization, with at least 2 speakers and at most 10 speakers >>> diarization = pipeline("/path/to/audio.wav", min_speakers=2, max_speakers=10) + # perform diarization and get one representative embedding per speaker + >>> diarization, embeddings = pipeline("/path/to/audio.wav", return_embedding=True) + >>> for s, speaker in enumerate(diarization.labels()): + ... # embeddings[s] is the embedding of speaker `speaker` + Hyper-parameters ---------------- segmentation.threshold @@ -417,6 +426,7 @@ def apply( num_speakers: int = None, min_speakers: int = None, max_speakers: int = None, + return_embeddings: bool = False, hook: Optional[Callable] = None, ) -> Annotation: """Apply speaker diarization @@ -431,6 +441,8 @@ def apply( Minimum number of speakers. Has no effect when `num_speakers` is provided. max_speakers : int, optional Maximum number of speakers. Has no effect when `num_speakers` is provided. + return_embeddings : bool, optional + Return representative speaker embeddings. hook : callable, optional Callback called after each major steps of the pipeline as follows: hook(step_name, # human-readable name of current step @@ -444,6 +456,10 @@ def apply( ------- diarization : Annotation Speaker diarization + embeddings : np.array, optional + Representative speaker embeddings such that `embeddings[i]` is the + speaker embedding for i-th speaker in diarization.labels(). + Only returned when `return_embeddings` is True. """ # setup hook (e.g. for debugging purposes) @@ -473,7 +489,11 @@ def apply( # exit early when no speaker is ever active if np.nanmax(count.data) == 0.0: - return Annotation(uri=file["uri"]) + diarization = Annotation(uri=file["uri"]) + if return_embeddings: + return diarization, np.zeros((0, self._embedding.dimension)) + + return diarization # binarize segmentation if self._segmentation.model.specifications.powerset: @@ -485,7 +505,7 @@ def apply( initial_state=False, ) - if self.klustering == "OracleClustering": + if self.klustering == "OracleClustering" and not return_embeddings: embeddings = None else: embeddings = self.get_embeddings( @@ -497,7 +517,7 @@ def apply( hook("embeddings", embeddings) # shape: (num_chunks, local_num_speakers, dimension) - hard_clusters, _ = self.clustering( + hard_clusters, _, centroids = self.clustering( embeddings=embeddings, segmentations=binarized_segmentations, num_clusters=num_speakers, @@ -506,7 +526,8 @@ def apply( file=file, # <== for oracle clustering frames=self._frames, # <== for oracle clustering ) - # hard_clusters: (num_chunks, num_speakers) + # hard_clusters: (num_chunks, num_speakers) + # centroids: (num_speakers, dimension) # reconstruct discrete diarization from raw hard clusters @@ -530,20 +551,52 @@ def apply( ) diarization.uri = file["uri"] - # when reference is available, use it to map hypothesized speakers - # to reference speakers (this makes later error analysis easier - # but does not modify the actual output of the diarization pipeline) + # at this point, `diarization` speaker labels are integers + # from 0 to `num_speakers - 1`, aligned with `centroids` rows. + if "annotation" in file and file["annotation"]: - return self.optimal_mapping(file["annotation"], diarization) + # when reference is available, use it to map hypothesized speakers + # to reference speakers (this makes later error analysis easier + # but does not modify the actual output of the diarization pipeline) + _, mapping = self.optimal_mapping( + file["annotation"], diarization, return_mapping=True + ) + + # in case there are more speakers in the hypothesis than in + # the reference, those extra speakers are missing from `mapping`. + # we add them back here + mapping = {key: mapping.get(key, key) for key in diarization.labels()} - # when reference is not available, rename hypothesized speakers - # to human-readable SPEAKER_00, SPEAKER_01, ... - return diarization.rename_labels( - { + else: + # when reference is not available, rename hypothesized speakers + # to human-readable SPEAKER_00, SPEAKER_01, ... + mapping = { label: expected_label for label, expected_label in zip(diarization.labels(), self.classes()) } - ) + + diarization = diarization.rename_labels(mapping=mapping) + + # at this point, `diarization` speaker labels are strings (or mix of + # strings and integers when reference is available and some hypothesis + # speakers are not present in the reference) + + if not return_embeddings: + return diarization + + # re-order centroids so that they match + # the order given by diarization.labels() + inverse_mapping = {label: index for index, label in mapping.items()} + centroids = centroids[ + [inverse_mapping[label] for label in diarization.labels()] + ] + + # FIXME: the number of centroids may be smaller than the number of speakers + # in the annotation. This can happen if the number of active speakers + # obtained from `speaker_count` for some frames is larger than the number + # of clusters obtained from `clustering`. Will be fixed in the future + + return diarization, centroids def get_metric(self) -> GreedyDiarizationErrorRate: return GreedyDiarizationErrorRate(**self.der_variant) diff --git a/pyannote/audio/pipelines/utils/diarization.py b/pyannote/audio/pipelines/utils/diarization.py index de07524e6..f494c6073 100644 --- a/pyannote/audio/pipelines/utils/diarization.py +++ b/pyannote/audio/pipelines/utils/diarization.py @@ -20,10 +20,11 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Mapping, Tuple, Union +from typing import Dict, Mapping, Tuple, Union import numpy as np from pyannote.core import Annotation, SlidingWindow, SlidingWindowFeature +from pyannote.core.utils.types import Label from pyannote.metrics.diarization import DiarizationErrorRate from pyannote.audio.core.inference import Inference @@ -74,8 +75,10 @@ def set_num_speakers( @staticmethod def optimal_mapping( - reference: Union[Mapping, Annotation], hypothesis: Annotation - ) -> Annotation: + reference: Union[Mapping, Annotation], + hypothesis: Annotation, + return_mapping: bool = False, + ) -> Union[Annotation, Tuple[Annotation, Dict[Label, Label]]]: """Find the optimal bijective mapping between reference and hypothesis labels Parameters @@ -84,13 +87,19 @@ def optimal_mapping( Reference annotation. Can be an Annotation instance or a mapping with an "annotation" key. hypothesis : Annotation + Hypothesized annotation. + return_mapping : bool, optional + Return the label mapping itself along with the mapped annotation. Defaults to False. Returns ------- mapped : Annotation Hypothesis mapped to reference speakers. - + mapping : dict, optional + Mapping between hypothesis (key) and reference (value) labels + Only returned if `return_mapping` is True. """ + if isinstance(reference, Mapping): reference = reference["annotation"] annotated = reference["annotated"] if "annotated" in reference else None @@ -100,7 +109,13 @@ def optimal_mapping( mapping = DiarizationErrorRate().optimal_mapping( reference, hypothesis, uem=annotated ) - return hypothesis.rename_labels(mapping=mapping) + mapped_hypothesis = hypothesis.rename_labels(mapping=mapping) + + if return_mapping: + return mapped_hypothesis, mapping + + else: + return mapped_hypothesis # TODO: get rid of onset/offset (binarization should be applied before calling speaker_count) # TODO: get rid of warm-up parameter (trimming should be applied before calling speaker_count) @@ -171,7 +186,8 @@ def to_annotation( Returns ------- continuous_diarization : Annotation - Continuous diarization + Continuous diarization, with speaker labels as integers, + corresponding to the speaker indices in the discrete diarization. """ binarize = Binarize( From cf0e3b398bee4f3ae8ae0d5f13b54fec8cfd2ca6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Tue, 27 Jun 2023 13:41:41 +0200 Subject: [PATCH 23/83] fix: fix missed speech at the very beginning/end --- pyannote/audio/pipelines/speaker_diarization.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyannote/audio/pipelines/speaker_diarization.py b/pyannote/audio/pipelines/speaker_diarization.py index f59551176..0bc0f449a 100644 --- a/pyannote/audio/pipelines/speaker_diarization.py +++ b/pyannote/audio/pipelines/speaker_diarization.py @@ -482,6 +482,7 @@ def apply( if self._segmentation.model.specifications.powerset else self.segmentation.threshold, frames=self._frames, + warm_up=(0.0, 0.0), ) hook("speaker_counting", count) # shape: (num_frames, 1) From f48b74f0a0265f3794c709855b1f55bb6263da63 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 27 Jun 2023 21:53:46 +0200 Subject: [PATCH 24/83] add losses computation in `training_step` method --- .../models/joint/end_to_end_diarization.py | 11 +- .../speaker_diarization_and_embedding.py | 216 ++++++++++-------- 2 files changed, 124 insertions(+), 103 deletions(-) diff --git a/pyannote/audio/models/joint/end_to_end_diarization.py b/pyannote/audio/models/joint/end_to_end_diarization.py index e97e449e4..031b1fce9 100644 --- a/pyannote/audio/models/joint/end_to_end_diarization.py +++ b/pyannote/audio/models/joint/end_to_end_diarization.py @@ -27,7 +27,6 @@ import torch from torch import nn import torch.nn.functional as F -from pytorch_metric_learning.losses import ArcFaceLoss from pyannote.audio.core.model import Model from pyannote.audio.core.task import Task @@ -189,14 +188,6 @@ def build(self): diarization_spec.powerset_max_classes, ) - self.arc_face_loss = ArcFaceLoss( - len(self.specifications[Subtasks.index("embedding")].classes), - self.hparams["embedding_dim"], - margin=self.task.margin, - scale=self.task.scale, - ) - - def forward(self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor: """ @@ -234,7 +225,7 @@ def forward(self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = Non weights = self.powerset(diarization_outputs).transpose(1, 2) # embedding part: - embedding_outputs = torch.clone(common_outputs) + embedding_outputs = common_outputs for tdnn_block in self.tdnn_blocks[tdnn_idx:]: embedding_outputs = tdnn_block(embedding_outputs) embedding_outputs = self.stats_pool(embedding_outputs, weights=weights) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 4978c9d73..de4f059e7 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -33,6 +33,7 @@ from torchmetrics import Metric from torchmetrics.classification import BinaryAUROC from torch.utils.data._utils.collate import default_collate +from pytorch_metric_learning.losses import ArcFaceLoss from pyannote.core import Segment, SlidingWindowFeature from pyannote.audio.core.task import Problem, Resolution, Specifications, Task @@ -83,6 +84,7 @@ def __init__( duration: float = 5.0, max_speakers_per_chunk: int = 3, max_speakers_per_frame: int = 2, + weigh_by_cardinality: bool = False, batch_size: int = 32, database_ratio : float = 0.5, num_workers: int = None, @@ -101,6 +103,7 @@ def __init__( augmentation=augmentation, ) + self.weigh_by_cardinality = weigh_by_cardinality self.max_speakers_per_chunk = max_speakers_per_chunk self.max_speakers_per_frame = max_speakers_per_frame self.database_ratio = database_ratio @@ -512,7 +515,6 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float, subtas sample["y"] = SlidingWindowFeature( y, self.model.example_output[subtask].frames, labels=labels ) - metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} sample["meta"]["file"] = file_id @@ -717,22 +719,24 @@ def collate_y(self, batch) -> torch.Tensor: zeros (artificial inactive speakers). """ - collated_y = [] + collated_y_dia = [] + collate_y_emb = [] + for b in batch: - y = b["y"].data - num_speakers = len(b["y"].labels) + y_dia = b["y"].data + labels = b["y"].labels + num_speakers = len(labels) if num_speakers > self.max_speakers_per_chunk: # sort speakers in descending talkativeness order - indices = np.argsort(-np.sum(y, axis=0), axis=0) + indices = np.argsort(-np.sum(y_dia, axis=0), axis=0) # keep only the most talkative speakers - y = y[:, indices[: self.max_speakers_per_chunk]] - + y_dia = y_dia[:, indices[: self.max_speakers_per_chunk]] # TODO: we should also sort the speaker labels in the same way elif num_speakers < self.max_speakers_per_chunk: # create inactive speakers by zero padding - y = np.pad( - y, + y_dia = np.pad( + y_dia, ((0, 0), (0, self.max_speakers_per_chunk - num_speakers)), mode="constant", ) @@ -741,9 +745,16 @@ def collate_y(self, batch) -> torch.Tensor: # we have exactly the right number of speakers pass - collated_y.append(y) + # embedding reference + y_emb = np.full((self.max_speakers_per_chunk,), -1, dtype=np.int) + if b["meta"]["scope"] > 1: + y_emb[: len(labels)] = labels[:] + + collated_y_dia.append(y_dia) + collate_y_emb.append(y_emb) - return torch.from_numpy(np.stack(collated_y)) + return (torch.from_numpy(np.stack(collated_y_dia)), + torch.from_numpy(np.stack(collate_y_emb)).squeeze(1)) def collate_meta(self, batch) -> torch.Tensor: """Collate for metadata""" @@ -772,7 +783,7 @@ def collate_fn(self, batch, stage="train"): # collate X collated_X = self.collate_X(batch) # collate y - collated_y = self.collate_y(batch) + collated_y_dia, collate_y_emb = self.collate_y(batch) # collate metadata collated_meta = self.collate_meta(batch) @@ -782,15 +793,23 @@ def collate_fn(self, batch, stage="train"): augmented = self.augmentation( samples=collated_X, sample_rate=self.model.hparams.sample_rate, - targets=collated_y.unsqueeze(1), + targets=collated_y_dia.unsqueeze(1), ) return { "X": augmented.samples, - "y": augmented.targets.squeeze(1), + "y_dia": augmented.targets.squeeze(1), + "y_emb": collate_y_emb, "meta": collated_meta, } + def setup_loss_func(self): + self.model.arc_face_loss = ArcFaceLoss( + len(self.specifications[Subtasks.index("embedding")].classes), + self.model.hparams["embedding_dim"], + margin=self.margin, + scale=self.scale, + ) def segmentation_loss( self, @@ -830,102 +849,77 @@ def segmentation_loss( return seg_loss - def compute_diarization_loss(self, batch : torch.Tensor): + def compute_diarization_loss(self, dia_chunks_idx, dia_prediction, permutated_target): """""" - X, y = batch["X"], batch["y"] - # drop samples that contain too many speakers - num_speakers: torch.Tensor = torch.sum(torch.any(target, dim=1), dim=1) - keep : torch.Tensor = num_speakers <= self.max_speakers_per_chunk - target = target[keep] - # TODO using variable `waveform` before assignment - waveform = waveform[keep] - - # log effective batch size - self.model.log( - f"{self.logging_prefix}BatchSize", - keep.sum(), - prog_bar=False, - logger=True, - on_step=False, - on_epoch=True, - reduce_fx="mean", - ) - # corner case - if not keep.any(): - return None - - # forward pass - prediction = self.model(waveform) - batch_size, num_frames, _ = prediction.shape - # (batch_size, num_frames, num_classes) - # frames weight - weight_key = getattr(self, "weight", None) - weight = batch.get( - weight_key, - torch.ones(batch_size, num_frames, 1, device=self.model.device), - ) - # (batch_size, num_frames, 1) - - # warm-up - warm_up_left = round(self.warm_up[0] / self.duration * num_frames) - weight[:, :warm_up_left] = 0.0 - warm_up_right = round(self.warm_up[1] / self.duration * num_frames) - weight[:, num_frames - warm_up_right :] = 0.0 - - powerset = torch.nn.functional.one_hot( - torch.argmax(prediction, dim=-1), - self.model.powerset.num_powerset_classes, - ).float() - multilabel = self.model.powerset.to_multilabel(powerset) - permutated_target, _ = permutate(multilabel, target) - permutated_target_powerset = self.model.powerset.to_powerset( - permutated_target.float() - ) - seg_loss = self.segmentation_loss( - prediction, permutated_target_powerset, weight=weight - ) + # Get chunks corresponding to the diarization subtask + diarization_chunks = dia_prediction[dia_chunks_idx] + # Get the permutated reference corresponding to diarization subtask + permutated_target_dia = permutated_target[dia_chunks_idx] + # Compute segmentation loss + diarization_loss = self.segmentation_loss(diarization_chunks, + permutated_target_dia) self.model.log( - f"{self.logging_prefix}TrainSegLoss", - seg_loss, + "loss/val", + diarization_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True, ) - loss = seg_loss - # skip batch if something went wrong for some reason - if torch.isnan(loss): - return None - - self.model.log( - f"{self.logging_prefix}TrainLoss", - loss, + self.model.log_dict( + self.model.validation_metric, on_step=False, on_epoch=True, prog_bar=True, logger=True, ) - return {"loss": loss} + return diarization_loss - def compute_embedding_loss(self, batch : torch.Tensor): - X, y = batch["X", batch["y"]] - loss = self.model.loss_func(self.model(X), y) + def compute_embedding_loss(self, emb_chunks_idx, emb_prediction, target_emb , permut_map): + """""" + + global_spks_id = [] + embeddings = torch.Tensor(device=self.model.device) + + emb_chunks = emb_prediction[emb_chunks_idx] + # (num_emb_chunks, num_spk, emb_dim) + target_chunks_classes = target_emb[emb_chunks_idx] + # (num_emb_chunk, num_spk) + permut_map_emb = permut_map[emb_chunks_idx] + # (num_emb_chunk, num_spk) + + for chunk_id in range(emb_chunks.shape[0]): + current_chunk_target = target_chunks_classes[chunk_id] + + # for each active speaker in the chunk + for chunk_spk_id in np.argwhere(current_chunk_target != -1).reshape((-1,)): + # get permutation map for current chunk + permut_map_chunk = permut_map_emb[chunk_id] + # get embedding index in the prediction + emb_idx = np.where(permut_map_chunk == int(chunk_spk_id))[0] + global_spks_id.append(int(current_chunk_target[chunk_spk_id])) + chunk_spk_emb = emb_chunks[chunk_id, emb_idx , :] + # add current embedding to the tensor of embeddings + embeddings = torch.concat((embeddings,chunk_spk_emb), dim=0) + + global_spks_id = torch.tensor(global_spks_id, device=self.model.device, dtype=torch.int64) + embedding_loss = self.model.arc_face_loss(embeddings, global_spks_id) # skip batch if something went wrong for some reason - if torch.isnan(loss): + if torch.isnan(embedding_loss): return None self.model.log( - f"{self.logging_prefix}TrainLoss", - loss, + "loss/val/arcface", + embedding_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, ) - return {"loss": loss} + return embedding_loss def training_step(self, batch, batch_idx: int): """Compute loss for the joint task @@ -942,15 +936,51 @@ def training_step(self, batch, batch_idx: int): loss : {str: torch.tensor} {"loss": loss} """ - alpha = self.alpha - if batch["task"] == "diarization": - # compute diarization loss - diarization_loss = self.compute_diarization_loss(batch=batch) - if batch["task"] == "embedding": - # compute embedding loss - embedding_loss = self.compute_embedding_loss(batch=batch) - loss = alpha * diarization_loss + (1 - alpha) * embedding_loss + # batch waveforms (batch_size, num_channels, num_samples) + waveform = batch["X"] + # batch diarization references (batch_size, num_channels, num_speakers) + target_dia = batch["y_dia"] + # batch embedding references (batch, num_speakers) + target_emb = batch["y_emb"] + + # drop samples that contain too many speakers + num_speakers: torch.Tensor = torch.sum(torch.any(target_dia, dim=1), dim=1) + keep: torch.Tensor = num_speakers <= self.max_speakers_per_chunk + target_dia = target_dia[keep] + target_emb = target_emb[keep] + waveform = waveform[keep] + + # corner case + if not keep.any(): + return None + + # forward pass + dia_prediction, emb_prediction = self.model(waveform) + # (batch_size, num_frames, num_spk), (batch_size, num_spk, emb_size) + + # get the best permutation + dia_multilabel = self.model.powerset.to_multilabel(dia_prediction) + permutated_target, permut_map = permutate(dia_multilabel, target_dia) + permut_map = np.array(permut_map) + + permutated_target_powerset = self.model.powerset.to_powerset( + permutated_target.float() + ) + + # Get chunk indexes in the batch for each subtask + emb_chunks_idx = np.nonzero(torch.any(target_emb != -1, axis=1)).reshape((-1,)) + dia_chunks_idx = np.nonzero(torch.all(target_emb == -1, axis=1)).reshape((-1,)) + + dia_loss = self.compute_diarization_loss(dia_chunks_idx, dia_prediction, permutated_target_powerset) + + emb_loss = 0 + # if batch contains embedding subtask chunks, then compute embedding loss on these chunks: + if emb_chunks_idx.any(): + emb_loss = self.compute_embedding_loss(emb_chunks_idx, emb_prediction, target_emb, permut_map) + + loss = alpha * dia_loss + (1 - alpha) * emb_loss + return {"loss": loss} def default_metric( self, From f3935464065f743496df9558b7badaaa5a827c9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Wed, 28 Jun 2023 09:09:53 +0200 Subject: [PATCH 25/83] doc: add note to self regarding cluster reassignment (#1419) --- pyannote/audio/pipelines/clustering.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyannote/audio/pipelines/clustering.py b/pyannote/audio/pipelines/clustering.py index 3c2786232..a779016cb 100644 --- a/pyannote/audio/pipelines/clustering.py +++ b/pyannote/audio/pipelines/clustering.py @@ -190,8 +190,9 @@ def assign_embeddings( else: hard_clusters = np.argmax(soft_clusters, axis=2) - # TODO: add a flag to revert argmax for trainign subset - # hard_clusters[train_chunk_idx, train_speaker_idx] = train_clusters + # NOTE: train_embeddings might be reassigned to a different cluster + # in the process. based on experiments, this seems to lead to better + # results than sticking to the original assignment. return hard_clusters, soft_clusters, centroids From 57185935087d476af348610451ea06c31b5db17f Mon Sep 17 00:00:00 2001 From: clement-pages Date: Wed, 28 Jun 2023 16:26:45 +0200 Subject: [PATCH 26/83] remove for loops in embedding loss computation as these loop could break gradient flow and to optimize the code --- .../speaker_diarization_and_embedding.py | 38 ++++++------------- 1 file changed, 12 insertions(+), 26 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index de4f059e7..115c411dc 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -879,32 +879,18 @@ def compute_diarization_loss(self, dia_chunks_idx, dia_prediction, permutated_ta def compute_embedding_loss(self, emb_chunks_idx, emb_prediction, target_emb , permut_map): """""" - - global_spks_id = [] - embeddings = torch.Tensor(device=self.model.device) - - emb_chunks = emb_prediction[emb_chunks_idx] - # (num_emb_chunks, num_spk, emb_dim) - target_chunks_classes = target_emb[emb_chunks_idx] - # (num_emb_chunk, num_spk) permut_map_emb = permut_map[emb_chunks_idx] # (num_emb_chunk, num_spk) + + # get all active speakers in embedding task chunks target + chunks_spk_id = torch.argwhere(target_emb != -1)[:, 1] + # Get corresponding embeddings indexes in chunks predictions + emb_idx = torch.where(permut_map_emb == chunks_spk_id.reshape((-1, 1))) + # Get the speaker embeddings + embeddings = emb_prediction[emb_idx[0], emb_idx[1]] + # Get global speaker idx + global_spks_id = target_emb[chunks_spk_id[:, 0], chunks_spk_id[:, 1]] - for chunk_id in range(emb_chunks.shape[0]): - current_chunk_target = target_chunks_classes[chunk_id] - - # for each active speaker in the chunk - for chunk_spk_id in np.argwhere(current_chunk_target != -1).reshape((-1,)): - # get permutation map for current chunk - permut_map_chunk = permut_map_emb[chunk_id] - # get embedding index in the prediction - emb_idx = np.where(permut_map_chunk == int(chunk_spk_id))[0] - global_spks_id.append(int(current_chunk_target[chunk_spk_id])) - chunk_spk_emb = emb_chunks[chunk_id, emb_idx , :] - # add current embedding to the tensor of embeddings - embeddings = torch.concat((embeddings,chunk_spk_emb), dim=0) - - global_spks_id = torch.tensor(global_spks_id, device=self.model.device, dtype=torch.int64) embedding_loss = self.model.arc_face_loss(embeddings, global_spks_id) # skip batch if something went wrong for some reason @@ -962,15 +948,15 @@ def training_step(self, batch, batch_idx: int): # get the best permutation dia_multilabel = self.model.powerset.to_multilabel(dia_prediction) permutated_target, permut_map = permutate(dia_multilabel, target_dia) - permut_map = np.array(permut_map) + permut_map = torch.tensor(data=permut_map, device=self.model.device) permutated_target_powerset = self.model.powerset.to_powerset( permutated_target.float() ) # Get chunk indexes in the batch for each subtask - emb_chunks_idx = np.nonzero(torch.any(target_emb != -1, axis=1)).reshape((-1,)) - dia_chunks_idx = np.nonzero(torch.all(target_emb == -1, axis=1)).reshape((-1,)) + emb_chunks_idx = torch.nonzero(torch.any(target_emb != -1, axis=1)).reshape((-1,)) + dia_chunks_idx = torch.nonzero(torch.all(target_emb == -1, axis=1)).reshape((-1,)) dia_loss = self.compute_diarization_loss(dia_chunks_idx, dia_prediction, permutated_target_powerset) From 80365729ecc18ee3ecd9ba321d62641a6b9f0479 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Mon, 3 Jul 2023 15:40:20 +0200 Subject: [PATCH 27/83] add validation part into the multi-task --- .../speaker_diarization_and_embedding.py | 212 +++++++++++++++++- 1 file changed, 209 insertions(+), 3 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 115c411dc..7028eafaf 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -23,7 +23,7 @@ from collections import defaultdict import itertools import random -from typing import Literal, List, Text, Union, Sequence, Dict +from typing import Literal, Union, Sequence, Dict import warnings import numpy as np import torch @@ -36,10 +36,11 @@ from pytorch_metric_learning.losses import ArcFaceLoss from pyannote.core import Segment, SlidingWindowFeature -from pyannote.audio.core.task import Problem, Resolution, Specifications, Task +from pyannote.audio.core.task import Problem, Resolution, Specifications from pyannote.audio.utils.loss import nll_loss from pyannote.audio.utils.permutation import permutate from pyannote.audio.utils.random import create_rng_for_worker +from pyannote.audio.tasks import SpeakerDiarization from pyannote.database.protocol import SpeakerDiarizationProtocol from pyannote.database.protocol.protocol import Scope, Subset from pyannote.audio.torchmetrics.classification import EqualErrorRate @@ -57,7 +58,7 @@ Subtasks = list(Subtask.__args__) -class JointSpeakerDiarizationAndEmbedding(Task): +class JointSpeakerDiarizationAndEmbedding(SpeakerDiarization): """Joint speaker diarization and embedding task Usage @@ -967,6 +968,211 @@ def training_step(self, batch, batch_idx: int): loss = alpha * dia_loss + (1 - alpha) * emb_loss return {"loss": loss} + + # TODO: no need to compute gradient in this method + def validation_step(self, batch, batch_idx: int): + """Compute validation loss and metric + + Parameters + ---------- + batch : dict of torch.Tensor + Current batch. + batch_idx: int + Batch index. + """ + + # target + target = batch["y"] + # (batch_size, num_frames, num_speakers) + + waveform = batch["X"] + # (batch_size, num_channels, num_samples) + + # TODO: should we handle validation samples with too many speakers + # waveform = waveform[keep] + # target = target[keep] + + # forward pass + prediction = self.model(waveform) + batch_size, num_frames, _ = prediction.shape + + # frames weight + weight_key = getattr(self, "weight", None) + weight = batch.get( + weight_key, + torch.ones(batch_size, num_frames, 1, device=self.model.device), + ) + # (batch_size, num_frames, 1) + + # warm-up + warm_up_left = round(self.warm_up[0] / self.duration * num_frames) + weight[:, :warm_up_left] = 0.0 + warm_up_right = round(self.warm_up[1] / self.duration * num_frames) + weight[:, num_frames - warm_up_right :] = 0.0 + + if self.specifications.powerset: + multilabel = self.model.powerset.to_multilabel(prediction) + permutated_target, _ = permutate(multilabel, target) + + # FIXME: handle case where target have too many speakers? + # since we don't need + permutated_target_powerset = self.model.powerset.to_powerset( + permutated_target.float() + ) + seg_loss = self.segmentation_loss( + prediction, permutated_target_powerset, weight=weight + ) + + else: + permutated_prediction, _ = permutate(target, prediction) + seg_loss = self.segmentation_loss( + permutated_prediction, target, weight=weight + ) + + self.model.log( + "loss/val/segmentation", + seg_loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + if self.vad_loss is None: + vad_loss = 0.0 + + else: + # TODO: vad_loss probably does not make sense in powerset mode + # because first class (empty set of labels) does exactly this... + if self.specifications.powerset: + vad_loss = self.voice_activity_detection_loss( + prediction, permutated_target_powerset, weight=weight + ) + + else: + vad_loss = self.voice_activity_detection_loss( + permutated_prediction, target, weight=weight + ) + + self.model.log( + "loss/val/vad", + vad_loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + loss = seg_loss + vad_loss + + self.model.log( + "loss/val", + loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + if self.specifications.powerset: + self.model.validation_metric( + torch.transpose( + multilabel[:, warm_up_left : num_frames - warm_up_right], 1, 2 + ), + torch.transpose( + target[:, warm_up_left : num_frames - warm_up_right], 1, 2 + ), + ) + else: + self.model.validation_metric( + torch.transpose( + prediction[:, warm_up_left : num_frames - warm_up_right], 1, 2 + ), + torch.transpose( + target[:, warm_up_left : num_frames - warm_up_right], 1, 2 + ), + ) + + self.model.log_dict( + self.model.validation_metric, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + # log first batch visualization every 2^n epochs. + if ( + self.model.current_epoch == 0 + or math.log2(self.model.current_epoch) % 1 > 0 + or batch_idx > 0 + ): + return + + # visualize first 9 validation samples of first batch in Tensorboard/MLflow + + if self.specifications.powerset: + y = permutated_target.float().cpu().numpy() + y_pred = multilabel.cpu().numpy() + else: + y = target.float().cpu().numpy() + y_pred = permutated_prediction.cpu().numpy() + + # prepare 3 x 3 grid (or smaller if batch size is smaller) + num_samples = min(self.batch_size, 9) + nrows = math.ceil(math.sqrt(num_samples)) + ncols = math.ceil(num_samples / nrows) + fig, axes = plt.subplots( + nrows=2 * nrows, ncols=ncols, figsize=(8, 5), squeeze=False + ) + + # reshape target so that there is one line per class when plotting it + y[y == 0] = np.NaN + if len(y.shape) == 2: + y = y[:, :, np.newaxis] + y *= np.arange(y.shape[2]) + + # plot each sample + for sample_idx in range(num_samples): + # find where in the grid it should be plotted + row_idx = sample_idx // nrows + col_idx = sample_idx % ncols + + # plot target + ax_ref = axes[row_idx * 2 + 0, col_idx] + sample_y = y[sample_idx] + ax_ref.plot(sample_y) + ax_ref.set_xlim(0, len(sample_y)) + ax_ref.set_ylim(-1, sample_y.shape[1]) + ax_ref.get_xaxis().set_visible(False) + ax_ref.get_yaxis().set_visible(False) + + # plot predictions + ax_hyp = axes[row_idx * 2 + 1, col_idx] + sample_y_pred = y_pred[sample_idx] + ax_hyp.axvspan(0, warm_up_left, color="k", alpha=0.5, lw=0) + ax_hyp.axvspan( + num_frames - warm_up_right, num_frames, color="k", alpha=0.5, lw=0 + ) + ax_hyp.plot(sample_y_pred) + ax_hyp.set_ylim(-0.1, 1.1) + ax_hyp.set_xlim(0, len(sample_y)) + ax_hyp.get_xaxis().set_visible(False) + + plt.tight_layout() + + for logger in self.model.loggers: + if isinstance(logger, TensorBoardLogger): + logger.experiment.add_figure("samples", fig, self.model.current_epoch) + elif isinstance(logger, MLFlowLogger): + logger.experiment.log_figure( + run_id=logger.run_id, + figure=fig, + artifact_file=f"samples_epoch{self.model.current_epoch}.png", + ) + + plt.close(fig) + def default_metric( self, From aa36d7be51fb22a4994f0651993da7d4f5d8933e Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 4 Jul 2023 22:44:59 +0200 Subject: [PATCH 28/83] remove `subtask` parameter from `prepare_chunk` --- .../tasks/joint_task/speaker_diarization_and_embedding.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 7028eafaf..73bab8530 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -440,7 +440,7 @@ def setup(self, stage="fit"): self.specifications = (speaker_diarization, speaker_embedding) - def prepare_chunk(self, file_id: int, start_time: float, duration: float, subtask: int): + def prepare_chunk(self, file_id: int, start_time: float, duration: float): """Prepare chunk Parameters @@ -514,12 +514,11 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float, subtas y[start:end, mapped_label] = 1 sample["y"] = SlidingWindowFeature( - y, self.model.example_output[subtask].frames, labels=labels + y, self.model.example_output[0].frames, labels=labels ) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} sample["meta"]["file"] = file_id - sample["meta"]["subtask"] = subtask return sample @@ -659,7 +658,7 @@ def train__iter__helper(self, rng : random.Random, **filters): embedding_class_idx += 1 file_id, start_time = self.draw_embedding_chunk(class_id, duration) - sample = self.prepare_chunk(file_id, start_time, duration, subtask) + sample = self.prepare_chunk(file_id, start_time, duration) yield sample def train__iter__(self): From 6617c9c50e3bbe20f05d9820463432b3b5014ad2 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 4 Jul 2023 22:46:08 +0200 Subject: [PATCH 29/83] fix bugs in validation part for now do the trick only for the diarization subtask --- .../speaker_diarization_and_embedding.py | 146 ++++++------------ 1 file changed, 44 insertions(+), 102 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 73bab8530..e193c8c74 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -21,10 +21,12 @@ # SOFTWARE. from collections import defaultdict +import math import itertools import random from typing import Literal, Union, Sequence, Dict import warnings +from matplotlib import pyplot as plt import numpy as np import torch @@ -34,6 +36,7 @@ from torchmetrics.classification import BinaryAUROC from torch.utils.data._utils.collate import default_collate from pytorch_metric_learning.losses import ArcFaceLoss +from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger from pyannote.core import Segment, SlidingWindowFeature from pyannote.audio.core.task import Problem, Resolution, Specifications @@ -881,14 +884,16 @@ def compute_embedding_loss(self, emb_chunks_idx, emb_prediction, target_emb , pe """""" permut_map_emb = permut_map[emb_chunks_idx] # (num_emb_chunk, num_spk) - + emb_chunks = emb_prediction[emb_chunks_idx] + + # TODO : to be simplified # get all active speakers in embedding task chunks target - chunks_spk_id = torch.argwhere(target_emb != -1)[:, 1] - # Get corresponding embeddings indexes in chunks predictions - emb_idx = torch.where(permut_map_emb == chunks_spk_id.reshape((-1, 1))) - # Get the speaker embeddings - embeddings = emb_prediction[emb_idx[0], emb_idx[1]] - # Get global speaker idx + chunks_spk_id = torch.argwhere(target_emb != -1) + # get corresponding embeddings indexes in chunks predictions + emb_idx = torch.where(permut_map_emb == chunks_spk_id[:, 1].reshape((-1, 1))) + # get the speaker embeddings + embeddings = emb_chunks[emb_idx[0], emb_idx[1]] + # get global speaker idx global_spks_id = target_emb[chunks_spk_id[:, 0], chunks_spk_id[:, 1]] embedding_loss = self.model.arc_face_loss(embeddings, global_spks_id) @@ -954,7 +959,7 @@ def training_step(self, batch, batch_idx: int): permutated_target.float() ) - # Get chunk indexes in the batch for each subtask + # get chunk indexes in the batch for each subtask emb_chunks_idx = torch.nonzero(torch.any(target_emb != -1, axis=1)).reshape((-1,)) dia_chunks_idx = torch.nonzero(torch.all(target_emb == -1, axis=1)).reshape((-1,)) @@ -962,13 +967,13 @@ def training_step(self, batch, batch_idx: int): emb_loss = 0 # if batch contains embedding subtask chunks, then compute embedding loss on these chunks: - if emb_chunks_idx.any(): + if emb_chunks_idx.shape[0] > 0: emb_loss = self.compute_embedding_loss(emb_chunks_idx, emb_prediction, target_emb, permut_map) loss = alpha * dia_loss + (1 - alpha) * emb_loss return {"loss": loss} - - # TODO: no need to compute gradient in this method + + # TODO: no need to compute gradient in this method def validation_step(self, batch, batch_idx: int): """Compute validation loss and metric @@ -981,8 +986,9 @@ def validation_step(self, batch, batch_idx: int): """ # target - target = batch["y"] + target_dia = batch["y_dia"] # (batch_size, num_frames, num_speakers) + target_emb = batch["y_emb"] waveform = batch["X"] # (batch_size, num_channels, num_samples) @@ -992,41 +998,20 @@ def validation_step(self, batch, batch_idx: int): # target = target[keep] # forward pass - prediction = self.model(waveform) - batch_size, num_frames, _ = prediction.shape - - # frames weight - weight_key = getattr(self, "weight", None) - weight = batch.get( - weight_key, - torch.ones(batch_size, num_frames, 1, device=self.model.device), - ) - # (batch_size, num_frames, 1) - - # warm-up - warm_up_left = round(self.warm_up[0] / self.duration * num_frames) - weight[:, :warm_up_left] = 0.0 - warm_up_right = round(self.warm_up[1] / self.duration * num_frames) - weight[:, num_frames - warm_up_right :] = 0.0 - - if self.specifications.powerset: - multilabel = self.model.powerset.to_multilabel(prediction) - permutated_target, _ = permutate(multilabel, target) - - # FIXME: handle case where target have too many speakers? - # since we don't need - permutated_target_powerset = self.model.powerset.to_powerset( - permutated_target.float() - ) - seg_loss = self.segmentation_loss( - prediction, permutated_target_powerset, weight=weight - ) + dia_prediction, emb_prediction = self.model(waveform) + batch_size, num_frames, _ = dia_prediction.shape - else: - permutated_prediction, _ = permutate(target, prediction) - seg_loss = self.segmentation_loss( - permutated_prediction, target, weight=weight - ) + multilabel = self.model.powerset.to_multilabel(dia_prediction) + permutated_target, _ = permutate(multilabel, target_dia) + + # FIXME: handle case where target have too many speakers? + # since we don't need + permutated_target_powerset = self.model.powerset.to_powerset( + permutated_target.float() + ) + seg_loss = self.segmentation_loss( + dia_prediction, permutated_target_powerset, + ) self.model.log( "loss/val/segmentation", @@ -1037,32 +1022,7 @@ def validation_step(self, batch, batch_idx: int): logger=True, ) - if self.vad_loss is None: - vad_loss = 0.0 - - else: - # TODO: vad_loss probably does not make sense in powerset mode - # because first class (empty set of labels) does exactly this... - if self.specifications.powerset: - vad_loss = self.voice_activity_detection_loss( - prediction, permutated_target_powerset, weight=weight - ) - - else: - vad_loss = self.voice_activity_detection_loss( - permutated_prediction, target, weight=weight - ) - - self.model.log( - "loss/val/vad", - vad_loss, - on_step=False, - on_epoch=True, - prog_bar=False, - logger=True, - ) - - loss = seg_loss + vad_loss + loss = seg_loss self.model.log( "loss/val", @@ -1073,24 +1033,14 @@ def validation_step(self, batch, batch_idx: int): logger=True, ) - if self.specifications.powerset: - self.model.validation_metric( - torch.transpose( - multilabel[:, warm_up_left : num_frames - warm_up_right], 1, 2 - ), - torch.transpose( - target[:, warm_up_left : num_frames - warm_up_right], 1, 2 - ), - ) - else: - self.model.validation_metric( - torch.transpose( - prediction[:, warm_up_left : num_frames - warm_up_right], 1, 2 - ), - torch.transpose( - target[:, warm_up_left : num_frames - warm_up_right], 1, 2 - ), - ) + self.model.validation_metric( + torch.transpose( + multilabel, 1, 2 + ), + torch.transpose( + target_dia, 1, 2 + ), + ) self.model.log_dict( self.model.validation_metric, @@ -1110,12 +1060,8 @@ def validation_step(self, batch, batch_idx: int): # visualize first 9 validation samples of first batch in Tensorboard/MLflow - if self.specifications.powerset: - y = permutated_target.float().cpu().numpy() - y_pred = multilabel.cpu().numpy() - else: - y = target.float().cpu().numpy() - y_pred = permutated_prediction.cpu().numpy() + y = permutated_target.float().cpu().numpy() + y_pred = multilabel.cpu().numpy() # prepare 3 x 3 grid (or smaller if batch size is smaller) num_samples = min(self.batch_size, 9) @@ -1149,10 +1095,6 @@ def validation_step(self, batch, batch_idx: int): # plot predictions ax_hyp = axes[row_idx * 2 + 1, col_idx] sample_y_pred = y_pred[sample_idx] - ax_hyp.axvspan(0, warm_up_left, color="k", alpha=0.5, lw=0) - ax_hyp.axvspan( - num_frames - warm_up_right, num_frames, color="k", alpha=0.5, lw=0 - ) ax_hyp.plot(sample_y_pred) ax_hyp.set_ylim(-0.1, 1.1) ax_hyp.set_xlim(0, len(sample_y)) @@ -1184,6 +1126,6 @@ def default_metric( "DiarizationErrorRate/Confusion": SpeakerConfusionRate(0.5), "DiarizationErrorRate/Miss": MissedDetectionRate(0.5), "DiarizationErrorRate/FalseAlarm": FalseAlarmRate(0.5), - "EqualErrorRate": EqualErrorRate(compute_on_cpu=True, distances=False), - "BinaryAUROC": BinaryAUROC(compute_on_cpu=True), + #"EqualErrorRate": EqualErrorRate(compute_on_cpu=True, distances=False), + #"BinaryAUROC": BinaryAUROC(compute_on_cpu=True), } From 60d5543a087899b48935760b964cf72364097b8d Mon Sep 17 00:00:00 2001 From: clement-pages Date: Wed, 5 Jul 2023 15:45:48 +0200 Subject: [PATCH 30/83] simplify the way embedding loss is calculated --- .../speaker_diarization_and_embedding.py | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index e193c8c74..b94093a3b 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -21,6 +21,7 @@ # SOFTWARE. from collections import defaultdict +from einops import rearrange import math import itertools import random @@ -863,7 +864,7 @@ def compute_diarization_loss(self, dia_chunks_idx, dia_prediction, permutated_ta diarization_loss = self.segmentation_loss(diarization_chunks, permutated_target_dia) self.model.log( - "loss/val", + "loss/train/segmentation", diarization_loss, on_step=False, on_epoch=True, @@ -880,30 +881,23 @@ def compute_diarization_loss(self, dia_chunks_idx, dia_prediction, permutated_ta ) return diarization_loss - def compute_embedding_loss(self, emb_chunks_idx, emb_prediction, target_emb , permut_map): + def compute_embedding_loss(self, emb_chunks, emb_prediction, target_emb): """""" - permut_map_emb = permut_map[emb_chunks_idx] - # (num_emb_chunk, num_spk) - emb_chunks = emb_prediction[emb_chunks_idx] - - # TODO : to be simplified - # get all active speakers in embedding task chunks target - chunks_spk_id = torch.argwhere(target_emb != -1) - # get corresponding embeddings indexes in chunks predictions - emb_idx = torch.where(permut_map_emb == chunks_spk_id[:, 1].reshape((-1, 1))) - # get the speaker embeddings - embeddings = emb_chunks[emb_idx[0], emb_idx[1]] - # get global speaker idx - global_spks_id = target_emb[chunks_spk_id[:, 0], chunks_spk_id[:, 1]] - - embedding_loss = self.model.arc_face_loss(embeddings, global_spks_id) + + # Get speaker representations from the embedding subtask + embeddings = emb_prediction[emb_chunks] + # Get corresponding target label + targets = target_emb[emb_chunks] + print(targets) + # compute the loss + embedding_loss = self.model.arc_face_loss(embeddings, targets) # skip batch if something went wrong for some reason if torch.isnan(embedding_loss): return None self.model.log( - "loss/val/arcface", + "loss/train/arcface", embedding_loss, on_step=False, on_epoch=True, @@ -952,23 +946,29 @@ def training_step(self, batch, batch_idx: int): # get the best permutation dia_multilabel = self.model.powerset.to_multilabel(dia_prediction) - permutated_target, permut_map = permutate(dia_multilabel, target_dia) - permut_map = torch.tensor(data=permut_map, device=self.model.device) + permutated_target_dia, permut_map = permutate(dia_multilabel, target_dia) + permutated_target_emb = target_emb[torch.arange(target_emb.shape[0]).unsqueeze(1), + permut_map] + + emb_prediction = rearrange(emb_prediction, "b s e -> (b s) e") + permutated_target_emb = rearrange(permutated_target_emb, "b s -> (b s)") permutated_target_powerset = self.model.powerset.to_powerset( - permutated_target.float() + permutated_target_dia.float() ) - # get chunk indexes in the batch for each subtask - emb_chunks_idx = torch.nonzero(torch.any(target_emb != -1, axis=1)).reshape((-1,)) - dia_chunks_idx = torch.nonzero(torch.all(target_emb == -1, axis=1)).reshape((-1,)) + # get embedding chunks position in current batch + emb_chunks = permutated_target_emb != -1 + # get diarization chunks position in current batch (that correspond to non embedding chunks) + dia_chunks = torch.nonzero(torch.all(target_emb == -1, axis=1)).reshape((-1,)) + #dia_chunks = emb_chunks == False - dia_loss = self.compute_diarization_loss(dia_chunks_idx, dia_prediction, permutated_target_powerset) + dia_loss = self.compute_diarization_loss(dia_chunks, dia_prediction, permutated_target_powerset) emb_loss = 0 # if batch contains embedding subtask chunks, then compute embedding loss on these chunks: - if emb_chunks_idx.shape[0] > 0: - emb_loss = self.compute_embedding_loss(emb_chunks_idx, emb_prediction, target_emb, permut_map) + if emb_chunks.shape[0] > 0: + emb_loss = self.compute_embedding_loss(emb_chunks, emb_prediction, permutated_target_emb) loss = alpha * dia_loss + (1 - alpha) * emb_loss return {"loss": loss} From 2834d3e3b623e7bb609d87060f7e8231aa545e1e Mon Sep 17 00:00:00 2001 From: clement-pages Date: Thu, 6 Jul 2023 14:12:34 +0200 Subject: [PATCH 31/83] handle case where there is no files from diarization dataset --- .../speaker_diarization_and_embedding.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index b94093a3b..2c97e30b7 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -631,8 +631,15 @@ def train__iter__helper(self, rng : random.Random, **filters): annotated_duration = self.annotated_duration[file_ids] # set duration of files for the embedding part to zero, in order to not # drawn them for diarization part - annotated_duration[embedding_files_ids] = 0 - prob_annotated_duration = annotated_duration / np.sum(annotated_duration) + annotated_duration[embedding_files_ids] = 0. + # test if there is at least one file for the diarization subtask to avoid + # to prevent probabilities from summing to zero + if np.any(annotated_duration != 0.): + prob_annotated_duration = annotated_duration / np.sum(annotated_duration) + else: + # There is only files for the embedding subtask + self.database_ratio = 0. + self.alpha = 0. duration = self.duration @@ -888,7 +895,6 @@ def compute_embedding_loss(self, emb_chunks, emb_prediction, target_emb): embeddings = emb_prediction[emb_chunks] # Get corresponding target label targets = target_emb[emb_chunks] - print(targets) # compute the loss embedding_loss = self.model.arc_face_loss(embeddings, targets) @@ -963,7 +969,10 @@ def training_step(self, batch, batch_idx: int): dia_chunks = torch.nonzero(torch.all(target_emb == -1, axis=1)).reshape((-1,)) #dia_chunks = emb_chunks == False - dia_loss = self.compute_diarization_loss(dia_chunks, dia_prediction, permutated_target_powerset) + dia_loss = 0 + #if batch contains diarization subtask chunks, then compute diarization loss on these chunks: + if dia_chunks.shape[0] > 0: + self.compute_diarization_loss(dia_chunks, dia_prediction, permutated_target_powerset) emb_loss = 0 # if batch contains embedding subtask chunks, then compute embedding loss on these chunks: From 35be745d3d77058fca1eaf318a3a060a0841bcfe Mon Sep 17 00:00:00 2001 From: ~/trisqwit <34287923+DiaaAj@users.noreply.github.com> Date: Sun, 9 Jul 2023 15:07:35 +0300 Subject: [PATCH 32/83] fix(doc): fix typo in diarization docstring --- pyannote/audio/pipelines/speaker_diarization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyannote/audio/pipelines/speaker_diarization.py b/pyannote/audio/pipelines/speaker_diarization.py index 0bc0f449a..18b6565d3 100644 --- a/pyannote/audio/pipelines/speaker_diarization.py +++ b/pyannote/audio/pipelines/speaker_diarization.py @@ -99,7 +99,7 @@ class SpeakerDiarization(SpeakerDiarizationMixin, Pipeline): >>> diarization = pipeline("/path/to/audio.wav", min_speakers=2, max_speakers=10) # perform diarization and get one representative embedding per speaker - >>> diarization, embeddings = pipeline("/path/to/audio.wav", return_embedding=True) + >>> diarization, embeddings = pipeline("/path/to/audio.wav", return_embeddings=True) >>> for s, speaker in enumerate(diarization.labels()): ... # embeddings[s] is the embedding of speaker `speaker` From 5628b48278ca2e37338a135849f8c4101e06ecf8 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 11 Jul 2023 16:21:06 +0200 Subject: [PATCH 33/83] fix size issue in `collate_y` when building embedding ref There was an issue when the number of speakers in a chunk was greater than the maximum number per chunk set for the task. --- .../speaker_diarization_and_embedding.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 2c97e30b7..09a2a2b18 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -734,9 +734,13 @@ def collate_y(self, batch) -> torch.Tensor: collate_y_emb = [] for b in batch: + # diarization reference y_dia = b["y"].data labels = b["y"].labels num_speakers = len(labels) + # embedding reference + y_emb = np.full((self.max_speakers_per_chunk,), -1, dtype=np.int) + if num_speakers > self.max_speakers_per_chunk: # sort speakers in descending talkativeness order indices = np.argsort(-np.sum(y_dia, axis=0), axis=0) @@ -744,6 +748,11 @@ def collate_y(self, batch) -> torch.Tensor: y_dia = y_dia[:, indices[: self.max_speakers_per_chunk]] # TODO: we should also sort the speaker labels in the same way + # if current chunck is for the embedding subtask + if b["meta"]["scope"] > 1: + labels = np.array(labels) + y_emb = labels[indices[: self.max_speakers_per_chunk]] + elif num_speakers < self.max_speakers_per_chunk: # create inactive speakers by zero padding y_dia = np.pad( @@ -751,15 +760,12 @@ def collate_y(self, batch) -> torch.Tensor: ((0, 0), (0, self.max_speakers_per_chunk - num_speakers)), mode="constant", ) + if b["meta"]["scope"] > 1: + y_emb[: num_speakers] = labels[:] else: - # we have exactly the right number of speakers - pass - - # embedding reference - y_emb = np.full((self.max_speakers_per_chunk,), -1, dtype=np.int) - if b["meta"]["scope"] > 1: - y_emb[: len(labels)] = labels[:] + if b["meta"]["scope"] > 1: + y_emb[: num_speakers] = labels[:] collated_y_dia.append(y_dia) collate_y_emb.append(y_emb) From c4988f4553d94de994f6bd5e1c42e1f468dc3546 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Wed, 12 Jul 2023 08:45:43 +0200 Subject: [PATCH 34/83] fix condition to compute `emb_loss` in `training_step` --- .../speaker_diarization_and_embedding.py | 29 ++++--------------- 1 file changed, 5 insertions(+), 24 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 09a2a2b18..986125153 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -877,17 +877,9 @@ def compute_diarization_loss(self, dia_chunks_idx, dia_prediction, permutated_ta diarization_loss = self.segmentation_loss(diarization_chunks, permutated_target_dia) self.model.log( - "loss/train/segmentation", + "loss/train/dia", diarization_loss, - on_step=False, - on_epoch=True, - prog_bar=False, - logger=True, - ) - - self.model.log_dict( - self.model.validation_metric, - on_step=False, + on_step=True, on_epoch=True, prog_bar=True, logger=True, @@ -911,7 +903,7 @@ def compute_embedding_loss(self, emb_chunks, emb_prediction, target_emb): self.model.log( "loss/train/arcface", embedding_loss, - on_step=False, + on_step=True, on_epoch=True, prog_bar=True, logger=True, @@ -982,7 +974,7 @@ def training_step(self, batch, batch_idx: int): emb_loss = 0 # if batch contains embedding subtask chunks, then compute embedding loss on these chunks: - if emb_chunks.shape[0] > 0: + if emb_chunks.any(): emb_loss = self.compute_embedding_loss(emb_chunks, emb_prediction, permutated_target_emb) loss = alpha * dia_loss + (1 - alpha) * emb_loss @@ -1029,7 +1021,7 @@ def validation_step(self, batch, batch_idx: int): ) self.model.log( - "loss/val/segmentation", + "loss/val/dia", seg_loss, on_step=False, on_epoch=True, @@ -1037,17 +1029,6 @@ def validation_step(self, batch, batch_idx: int): logger=True, ) - loss = seg_loss - - self.model.log( - "loss/val", - loss, - on_step=False, - on_epoch=True, - prog_bar=False, - logger=True, - ) - self.model.validation_metric( torch.transpose( multilabel, 1, 2 From 78b5b04197ffd303c6e46d93771144e1badd5730 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Wed, 12 Jul 2023 09:38:14 +0200 Subject: [PATCH 35/83] add missing docstrings --- .../speaker_diarization_and_embedding.py | 58 +++++++++++++++---- 1 file changed, 46 insertions(+), 12 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 986125153..9fd765076 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -866,49 +866,83 @@ def segmentation_loss( return seg_loss - def compute_diarization_loss(self, dia_chunks_idx, dia_prediction, permutated_target): - """""" + def compute_diarization_loss(self, dia_chunks, dia_prediction, permutated_target): + """Compute loss for the speaker diarization subtask + + Parameters + ---------- + dia_chunks : torch.Tensor + tensor specifying the chunks assigned to the speaker diarization + task in the current batch. Shape of (batch_size,) + dia_prediction : torch.Tensor + speaker diarization output predicted by the model for the current batch. + Shape of (batch_size, num_spk, num_frames) + permutated_target: torch.Tensor + permutated target for the current batch. Shape of (batch_size, num_spk, num_frames) + + Returns + ------- + dia_loss : torch.Tensor + Permutation-invariant diarization loss + """ # Get chunks corresponding to the diarization subtask - diarization_chunks = dia_prediction[dia_chunks_idx] + chunks_prediction = dia_prediction[dia_chunks] # Get the permutated reference corresponding to diarization subtask - permutated_target_dia = permutated_target[dia_chunks_idx] + permutated_target_dia = permutated_target[dia_chunks] # Compute segmentation loss - diarization_loss = self.segmentation_loss(diarization_chunks, + dia_loss = self.segmentation_loss(chunks_prediction, permutated_target_dia) self.model.log( "loss/train/dia", - diarization_loss, + dia_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, ) - return diarization_loss + return dia_loss def compute_embedding_loss(self, emb_chunks, emb_prediction, target_emb): - """""" + """Compute loss for the speaker embeddings extraction subtask + + Parameters + ---------- + emb_chunks : torch.Tensor + tensor specifying the chunks assigned to the speaker embeddings extraction + task in the current batch. Shape of (batch_size,) + emb_prediction : torch.Tensor + speaker embeddings predicted by the model for the current batch. + Shape of (batch_size * num_spk, embedding_dim) + target_emb : torch.Tensor + target embeddings for the current batch + Shape of (batch_size * num_spk,) + Returns + ------- + emb_loss : torch.Tensor + arcface loss for the current batch + """ # Get speaker representations from the embedding subtask embeddings = emb_prediction[emb_chunks] # Get corresponding target label targets = target_emb[emb_chunks] # compute the loss - embedding_loss = self.model.arc_face_loss(embeddings, targets) + emb_loss = self.model.arc_face_loss(embeddings, targets) # skip batch if something went wrong for some reason - if torch.isnan(embedding_loss): + if torch.isnan(emb_loss): return None self.model.log( "loss/train/arcface", - embedding_loss, + emb_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, ) - return embedding_loss + return emb_loss def training_step(self, batch, batch_idx: int): """Compute loss for the joint task From bdf356783526625914d143c1c1eb29a549a1798e Mon Sep 17 00:00:00 2001 From: clement-pages Date: Wed, 12 Jul 2023 09:46:09 +0200 Subject: [PATCH 36/83] remove redefinitions of `collate_X` and `collate_meta` these two methods were identical to the methods inherited from the `SegmentationTaskMixin` class --- .../tasks/joint_task/speaker_diarization_and_embedding.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 9fd765076..7cafcd005 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -709,10 +709,6 @@ def train__iter__(self): # generate random chunk yield next(chunks) - def collate_X(self, batch) -> torch.Tensor: - """Collate for data""" - return default_collate([b["X"] for b in batch]) - def collate_y(self, batch) -> torch.Tensor: """ Parameters @@ -773,10 +769,6 @@ def collate_y(self, batch) -> torch.Tensor: return (torch.from_numpy(np.stack(collated_y_dia)), torch.from_numpy(np.stack(collate_y_emb)).squeeze(1)) - def collate_meta(self, batch) -> torch.Tensor: - """Collate for metadata""" - return default_collate([b["meta"] for b in batch]) - def collate_fn(self, batch, stage="train"): """Collate function used for most segmentation tasks From aae90a0a771fffe5e2e7cbc31d9e73e281a080a7 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Wed, 12 Jul 2023 21:40:27 +0200 Subject: [PATCH 37/83] add missing `dia_loss` assignment and fix issue with the loss type during training --- .../tasks/joint_task/speaker_diarization_and_embedding.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 7cafcd005..fb8b891cf 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -991,14 +991,13 @@ def training_step(self, batch, batch_idx: int): emb_chunks = permutated_target_emb != -1 # get diarization chunks position in current batch (that correspond to non embedding chunks) dia_chunks = torch.nonzero(torch.all(target_emb == -1, axis=1)).reshape((-1,)) - #dia_chunks = emb_chunks == False - dia_loss = 0 + dia_loss = torch.tensor(0) #if batch contains diarization subtask chunks, then compute diarization loss on these chunks: if dia_chunks.shape[0] > 0: - self.compute_diarization_loss(dia_chunks, dia_prediction, permutated_target_powerset) + dia_loss = self.compute_diarization_loss(dia_chunks, dia_prediction, permutated_target_powerset) - emb_loss = 0 + emb_loss = torch.tensor(0) # if batch contains embedding subtask chunks, then compute embedding loss on these chunks: if emb_chunks.any(): emb_loss = self.compute_embedding_loss(emb_chunks, emb_prediction, permutated_target_emb) From d3b3efc074799b54fb2ffe13aed1b701492aa63d Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 18 Jul 2023 13:22:46 +0200 Subject: [PATCH 38/83] filter out the speaker in ref not found by diarization --- .../speaker_diarization_and_embedding.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index fb8b891cf..d16a07dce 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -888,7 +888,7 @@ def compute_diarization_loss(self, dia_chunks, dia_prediction, permutated_target self.model.log( "loss/train/dia", dia_loss, - on_step=True, + on_step=False, on_epoch=True, prog_bar=True, logger=True, @@ -929,7 +929,7 @@ def compute_embedding_loss(self, emb_chunks, emb_prediction, target_emb): self.model.log( "loss/train/arcface", emb_loss, - on_step=True, + on_step=False, on_epoch=True, prog_bar=True, logger=True, @@ -972,7 +972,7 @@ def training_step(self, batch, batch_idx: int): # forward pass dia_prediction, emb_prediction = self.model(waveform) - # (batch_size, num_frames, num_spk), (batch_size, num_spk, emb_size) + # (batch_size, num_frames, num_cls), (batch_size, num_spk, emb_size) # get the best permutation dia_multilabel = self.model.powerset.to_multilabel(dia_prediction) @@ -980,8 +980,14 @@ def training_step(self, batch, batch_idx: int): permutated_target_emb = target_emb[torch.arange(target_emb.shape[0]).unsqueeze(1), permut_map] - emb_prediction = rearrange(emb_prediction, "b s e -> (b s) e") - permutated_target_emb = rearrange(permutated_target_emb, "b s -> (b s)") + # filter out the speaker in the reference that were not found by the diarization + # part of the model, to not compute the embedding loss on these speaker: + active_spk_mask = torch.any(rearrange(dia_multilabel, "b f s -> b s f"), dim=2) + # (batch_size, num_spk) + emb_prediction = emb_prediction[active_spk_mask] + # (num_active_spk_found_in_all_the_chunks, emb_size) + permutated_target_emb = permutated_target_emb[active_spk_mask] + # (num_activate_spk_found,) permutated_target_powerset = self.model.powerset.to_powerset( permutated_target_dia.float() From 4289ea9d4594714760d9f3c211411bf6dbf5f472 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 18 Jul 2023 13:24:32 +0200 Subject: [PATCH 39/83] modifiy `start_time` possible values interval in `draw_embedding_chunk` --- .../audio/tasks/joint_task/speaker_diarization_and_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index d16a07dce..78826f114 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -599,7 +599,7 @@ class ID in the task speficiations segment = np.random.choice(class_segments, p=prob_segments) # sample chunk start time in order to intersect it with the sampled segment - start_time = np.random.uniform(max(segment["start"] - duration / 2, 0), segment["start"]) + start_time = np.random.uniform(max(segment["start"] - duration, 0), segment["end"]) return (segment["file_id"], start_time) From e9f40a3e9920ac34e5f99bf60f405c73f8548d3d Mon Sep 17 00:00:00 2001 From: clement-pages Date: Wed, 19 Jul 2023 22:51:05 +0200 Subject: [PATCH 40/83] add V2 of `SpeakerEndToEndDiarization` this version replace `StatsPool` by a concatenation of the last outputs of TDNN (for the embedding part) and LSTM (for the diarization part) and a LSTM layer --- pyannote/audio/models/joint/__init__.py | 4 +- .../models/joint/end_to_end_diarization.py | 208 ++++++++++++++++++ 2 files changed, 210 insertions(+), 2 deletions(-) diff --git a/pyannote/audio/models/joint/__init__.py b/pyannote/audio/models/joint/__init__.py index 3c6230fdb..f32ef8d98 100644 --- a/pyannote/audio/models/joint/__init__.py +++ b/pyannote/audio/models/joint/__init__.py @@ -20,6 +20,6 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from .end_to_end_diarization import SpeakerEndToEndDiarization +from .end_to_end_diarization import SpeakerEndToEndDiarization, SpeakerEndToEndDiarizationV2 -__all__ = ["SpeakerEndToEndDiarization"] +__all__ = ["SpeakerEndToEndDiarization", "SpeakerEndToEndDiarizationV2"] diff --git a/pyannote/audio/models/joint/end_to_end_diarization.py b/pyannote/audio/models/joint/end_to_end_diarization.py index 031b1fce9..14c565502 100644 --- a/pyannote/audio/models/joint/end_to_end_diarization.py +++ b/pyannote/audio/models/joint/end_to_end_diarization.py @@ -232,3 +232,211 @@ def forward(self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = Non embedding_outputs = self.embedding(embedding_outputs) return (diarization_outputs, embedding_outputs) + +class SpeakerEndToEndDiarizationV2(Model): + + SINCNET_DEFAULTS = {"stride": 10} + LSTM_DEFAULTS = { + "hidden_size": 128, + "num_layers": 2, + "bidirectional": True, + "monolithic": True, + "dropout": 0.0, + "batch_first": True, + } + LINEAR_DEFAULTS = {"hidden_size": 128, "num_layers": 2} + + def __init__( + self, + sincnet: dict = None, + lstm: dict= None, + linear: dict = None, + sample_rate: int = 16000, + num_channels: int = 1, + num_features: int = 60, + embedding_dim: int = 512, + separation_idx: int = 2, + task: Optional[Task] = None, + ): + super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) + + if num_features != 60: + warn("For now, the model only support a number of features of 60. Set it to 60") + num_features = 60 + self.num_features = num_features + self.separation_idx = separation_idx + self.save_hyperparameters("num_features", "embedding_dim", "separation_idx") + + + # sincnet module + sincnet = merge_dict(self.SINCNET_DEFAULTS, sincnet) + sincnet["sample_rate"] = sample_rate + self.sincnet =SincNet(**sincnet) + self.save_hyperparameters("sincnet") + + # tdnn modules + self.tdnn_blocks = nn.ModuleList() + in_channel = num_features + out_channels = [512, 512, 512, 512, 1500] + kernel_sizes = [5, 3, 3, 1, 1] + dilations = [1, 2, 3, 1, 1] + self.last_tdnn_out_channels = out_channels[-1] + + for out_channel, kernel_size, dilation in zip( + out_channels, kernel_sizes, dilations + ): + self.tdnn_blocks.extend( + [ + nn.Sequential( + nn.Conv1d( + in_channels=in_channel, + out_channels=out_channel, + kernel_size=kernel_size, + dilation=dilation, + ), + nn.LeakyReLU(), + nn.BatchNorm1d(out_channel), + ), + ] + ) + in_channel = out_channel + + # lstm modules: + lstm = merge_dict(self.LSTM_DEFAULTS, lstm) + self.save_hyperparameters("lstm") + monolithic = lstm["monolithic"] + if monolithic: + multi_layer_lstm = dict(lstm) + del multi_layer_lstm["monolithic"] + self.lstm = nn.LSTM(out_channels[separation_idx], **multi_layer_lstm) + else: + num_layers = lstm["num_layers"] + if num_layers > 1: + self.dropout = nn.Dropout(p=lstm["dropout"]) + + one_layer_lstm = dict(lstm) + del one_layer_lstm["monolithic"] + one_layer_lstm["num_layers"] = 1 + one_layer_lstm["dropout"] = 0.0 + + self.lstm = nn.ModuleList( + [ + nn.LSTM( + out_channels[separation_idx] + if i == 0 + else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1), + **one_layer_lstm + ) + for i in range(num_layers) + ] + ) + + # linear module for the diarization part: + linear = merge_dict(self.LINEAR_DEFAULTS, linear) + self.save_hyperparameters("linear") + if linear["num_layers"] < 1: + return + + lstm_out_features: int = self.hparams.lstm["hidden_size"] * ( + 2 if self.hparams.lstm["bidirectional"] else 1 + ) + self.linear = nn.ModuleList( + [ + nn.Linear(in_features, out_features) + for in_features, out_features in pairwise( + [ + lstm_out_features, + ] + + [self.hparams.linear["hidden_size"]] + * self.hparams.linear["num_layers"] + ) + ] + ) + + # linear module for the embedding part: + self.embedding = nn.Linear(self.hparams.lstm["hidden_size"], embedding_dim) + + def build(self): + if self.hparams.linear["num_layers"] > 0: + in_features = self.hparams.linear["hidden_size"] + else: + in_features = self.hparams.lstm["hidden_size"] * ( + 2 if self.hparams.lstm["bidirectional"] else 1 + ) + + diarization_spec = self.specifications[Subtasks.index("diarization")] + out_features = diarization_spec.num_powerset_classes + self.classifier = nn.Linear(in_features, out_features) + + self.powerset = Powerset( + len(diarization_spec.classes), + diarization_spec.powerset_max_classes, + ) + + lstm_out_features: int = self.hparams.lstm["hidden_size"] * ( + 2 if self.hparams.lstm["bidirectional"] else 1 + ) + self.encoder = nn.LSTM( + # number of channel in the outputs of the last TDNN layer + lstm_out_features + input_size= self.last_tdnn_out_channels + lstm_out_features, + hidden_size= len(diarization_spec.classes * self.hparams.lstm["hidden_size"]), + batch_first=True, + bidirectional=False, + ) + + def forward(self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + + Parameters + ---------- + waveforms : torch.Tensor + Batch of waveforms with shape (batch, channel, sample) + weights : torch.Tensor, optional + Batch of weights wiht shape (batch, frame) + """ + common_outputs = self.sincnet(waveforms) + # (batch, features, frames) + # common part to diarization and embedding: + tdnn_idx = 0 + while tdnn_idx <= self.separation_idx: + common_outputs = self.tdnn_blocks[tdnn_idx](common_outputs) + tdnn_idx = tdnn_idx + 1 + # diarization part: + dia_outputs = common_outputs + + if self.hparams.lstm["monolithic"]: + dia_outputs, _ = self.lstm( + rearrange(dia_outputs, "batch feature frame -> batch frame feature") + ) + else: + dia_outputs = rearrange(common_outputs, "batch feature frame -> batch frame feature") + for i, lstm in enumerate(self.lstm): + dia_outputs, _ = lstm(dia_outputs) + if i + 1 < self.hparams.lstm["num_layers"]: + dia_outputs = self.linear() + lstm_outputs = dia_outputs + if self.hparams.linear["num_layers"] > 0: + for linear in self.linear: + dia_outputs = F.leaky_relu(linear(dia_outputs)) + dia_outputs = self.classifier(dia_outputs) + dia_outputs = F.log_softmax(dia_outputs, dim=-1) + + # embedding part: + emb_outputs = common_outputs + for tdnn_block in self.tdnn_blocks[tdnn_idx:]: + emb_outputs = tdnn_block(emb_outputs) + + # there is a change in the number of frames in the embeddings section compared with the + # diarization section, due to the application of kernels in the last tdnn layers after separation: + emb_outputs = rearrange(emb_outputs, "b c f -> b f c") + frame_dim_diff = lstm_outputs.shape[1] - emb_outputs.shape[1] + if frame_dim_diff != 0: + lstm_outputs = lstm_outputs[:, frame_dim_diff // 2 : -(frame_dim_diff // 2), :] + # Concatenation of last tdnn layer outputs with the last diarization lstm outputs: + emb_outputs = torch.cat((emb_outputs, lstm_outputs), dim=2) + _, emb_outputs = self.encoder(emb_outputs) + emb_outputs = emb_outputs[0].squeeze(0) + emb_outputs = torch.reshape(emb_outputs, (emb_outputs.shape[0], self.powerset.num_classes, -1)) + emb_outputs = self.embedding(emb_outputs) + + return (dia_outputs, emb_outputs) From 3f7cb8a188f1858fb9bc4fc6ee8687ca769cb704 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Wed, 26 Jul 2023 08:52:21 +0200 Subject: [PATCH 41/83] Add `padding="same"` in model `Conv1d` layers --- pyannote/audio/models/joint/end_to_end_diarization.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/pyannote/audio/models/joint/end_to_end_diarization.py b/pyannote/audio/models/joint/end_to_end_diarization.py index 14c565502..c21307ac4 100644 --- a/pyannote/audio/models/joint/end_to_end_diarization.py +++ b/pyannote/audio/models/joint/end_to_end_diarization.py @@ -104,6 +104,7 @@ def __init__( out_channels=out_channel, kernel_size=kernel_size, dilation=dilation, + padding="same", ), nn.LeakyReLU(), nn.BatchNorm1d(out_channel), @@ -169,8 +170,6 @@ def __init__( # linear module for the embedding part: self.embedding = nn.Linear(in_channel * 2, embedding_dim) - - def build(self): if self.hparams.linear["num_layers"] > 0: in_features = self.hparams.linear["hidden_size"] @@ -233,6 +232,7 @@ def forward(self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = Non return (diarization_outputs, embedding_outputs) + class SpeakerEndToEndDiarizationV2(Model): SINCNET_DEFAULTS = {"stride": 10} @@ -293,6 +293,7 @@ def __init__( out_channels=out_channel, kernel_size=kernel_size, dilation=dilation, + padding="same" ), nn.LeakyReLU(), nn.BatchNorm1d(out_channel), @@ -426,12 +427,7 @@ def forward(self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = Non for tdnn_block in self.tdnn_blocks[tdnn_idx:]: emb_outputs = tdnn_block(emb_outputs) - # there is a change in the number of frames in the embeddings section compared with the - # diarization section, due to the application of kernels in the last tdnn layers after separation: emb_outputs = rearrange(emb_outputs, "b c f -> b f c") - frame_dim_diff = lstm_outputs.shape[1] - emb_outputs.shape[1] - if frame_dim_diff != 0: - lstm_outputs = lstm_outputs[:, frame_dim_diff // 2 : -(frame_dim_diff // 2), :] # Concatenation of last tdnn layer outputs with the last diarization lstm outputs: emb_outputs = torch.cat((emb_outputs, lstm_outputs), dim=2) _, emb_outputs = self.encoder(emb_outputs) From 0f1577d6cb2c63e64e02dad011396ee7a0c32cc5 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Wed, 26 Jul 2023 08:56:20 +0200 Subject: [PATCH 42/83] update LSTM encoder in SPEED V2 Now, this LSTM is bidirectionnal and has a hidden size of 1500, so the outputs shape of this encoder is (b, s, 1500*2). This will allow comparing with `StatsPool` version of the SPEED model --- .../audio/models/joint/end_to_end_diarization.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pyannote/audio/models/joint/end_to_end_diarization.py b/pyannote/audio/models/joint/end_to_end_diarization.py index c21307ac4..ec8dbcf88 100644 --- a/pyannote/audio/models/joint/end_to_end_diarization.py +++ b/pyannote/audio/models/joint/end_to_end_diarization.py @@ -234,6 +234,7 @@ def forward(self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = Non class SpeakerEndToEndDiarizationV2(Model): + """This version uses a LSTM encoder in the embedding branch instead StatsPool block""" SINCNET_DEFAULTS = {"stride": 10} LSTM_DEFAULTS = { @@ -355,7 +356,7 @@ def __init__( ) # linear module for the embedding part: - self.embedding = nn.Linear(self.hparams.lstm["hidden_size"], embedding_dim) + self.embedding = nn.Linear(2 * self.last_tdnn_out_channels, embedding_dim) def build(self): if self.hparams.linear["num_layers"] > 0: @@ -380,9 +381,9 @@ def build(self): self.encoder = nn.LSTM( # number of channel in the outputs of the last TDNN layer + lstm_out_features input_size= self.last_tdnn_out_channels + lstm_out_features, - hidden_size= len(diarization_spec.classes * self.hparams.lstm["hidden_size"]), + hidden_size= len(diarization_spec.classes) * self.last_tdnn_out_channels, batch_first=True, - bidirectional=False, + bidirectional=True, ) def forward(self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor: @@ -431,8 +432,9 @@ def forward(self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = Non # Concatenation of last tdnn layer outputs with the last diarization lstm outputs: emb_outputs = torch.cat((emb_outputs, lstm_outputs), dim=2) _, emb_outputs = self.encoder(emb_outputs) - emb_outputs = emb_outputs[0].squeeze(0) - emb_outputs = torch.reshape(emb_outputs, (emb_outputs.shape[0], self.powerset.num_classes, -1)) + emb_outputs = rearrange(emb_outputs[0], "l b h -> b (l h)") + emb_outputs = torch.reshape(emb_outputs, + (emb_outputs.shape[0], self.powerset.num_classes, -1)) emb_outputs = self.embedding(emb_outputs) return (dia_outputs, emb_outputs) From 933a66008987017d34e18920798bf4c8c160276a Mon Sep 17 00:00:00 2001 From: clement-pages Date: Fri, 13 Oct 2023 15:51:23 +0200 Subject: [PATCH 43/83] add `prepare_data` method in `Task` class The goal of this method is to generate the data needed by the task and save it on disk for future uses, for example by the `setup` method. The objective is to avoid systematically recreating data on each process at the beginning of a training --- pyannote/audio/core/model.py | 5 +- pyannote/audio/core/task.py | 350 +++++++++++++++++++- pyannote/audio/tasks/segmentation/mixins.py | 336 ++----------------- 3 files changed, 363 insertions(+), 328 deletions(-) diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index bedb7f6c4..7bec676b4 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -217,9 +217,12 @@ def __example_output( self.specifications, __example_output, example_output ) + def prepare_data(self): + self.task.prepare_data() + def setup(self, stage=None): if stage == "fit": - self.task.setup_metadata() + self.task.setup() # list of layers before adding task-dependent layers before = set((name, id(module)) for name, module in self.named_modules()) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 1edfbc35c..fdabc3ba5 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -23,11 +23,16 @@ from __future__ import annotations +import itertools import multiprocessing +import numpy as np +import os +import pickle import sys import warnings from dataclasses import dataclass from enum import Enum +from collections import defaultdict from functools import cached_property, partial from numbers import Number from typing import Dict, List, Literal, Optional, Sequence, Text, Tuple, Union @@ -36,6 +41,8 @@ import scipy.special import torch from pyannote.database import Protocol +from pyannote.database.protocol import SegmentationProtocol, SpeakerDiarizationProtocol +from pyannote.database.protocol.protocol import Scope, Subset from torch.utils.data import DataLoader, Dataset, IterableDataset from torch_audiomentations import Identity from torch_audiomentations.core.transforms_interface import BaseWaveformTransform @@ -44,6 +51,8 @@ from pyannote.audio.utils.loss import binary_cross_entropy, nll_loss from pyannote.audio.utils.protocol import check_protocol +Subsets = list(Subset.__args__) +Scopes = list(Scope.__args__) # Type of machine learning problem class Problem(Enum): @@ -196,6 +205,8 @@ class Task(pl.LightningDataModule): metric : optional Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. Defaults to value returned by `default_metric` method. + cache_path : str, optional + File path where store task-related data, especially data from protocol Attributes ---------- @@ -214,6 +225,7 @@ def __init__( pin_memory: bool = False, augmentation: BaseWaveformTransform = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + cache_path = "./cache/task_cache/protocle_data.pickle" ): super().__init__() @@ -253,6 +265,7 @@ def __init__( self.pin_memory = pin_memory self.augmentation = augmentation or Identity(output_type="dict") self._metric = metric + self.cache_path = cache_path def prepare_data(self): """Use this to download and prepare data @@ -266,7 +279,327 @@ def prepare_data(self): ----- Called only once. """ - pass + + if os.path.exists(self.cache_path): + # data was already created, do nothing + return + else: + #create the repo + os.makedirs(self.cache_path[:self.cache_path.rfind('/')]) + # duration of training chunks + # TODO: handle variable duration case + duration = getattr(self, "duration", 0.0) + + # list of possible values for each metadata key + metadata_unique_values = defaultdict(list) + + metadata_unique_values["subset"] = Subsets + + if isinstance(self.protocol, SpeakerDiarizationProtocol): + metadata_unique_values["scope"] = Scopes + + elif isinstance(self.protocol, SegmentationProtocol): + classes = getattr(self, "classes", list()) + + # make sure classes attribute exists (and set to None if it did not exist) + self.classes = getattr(self, "classes", None) + if self.classes is None: + classes = list() + # metadata_unique_values["classes"] = list(classes) + + audios = list() # list of path to audio files + audio_infos = list() + audio_encodings = list() + metadata = list() # list of metadata + + annotated_duration = list() # total duration of annotated regions (per file) + annotated_regions = list() # annotated regions + annotations = list() # actual annotations + annotated_classes = list() # list of annotated classes (per file) + unique_labels = list() + + if self.has_validation: + files_iter = itertools.chain( + self.protocol.train(), self.protocol.development() + ) + else: + files_iter = self.protocol.train() + + for file_id, file in enumerate(files_iter): + # gather metadata and update metadata_unique_values so that each metadatum + # (e.g. source database or label) is represented by an integer. + metadatum = dict() + + # keep track of source database and subset (train, development, or test) + if file["database"] not in metadata_unique_values["database"]: + metadata_unique_values["database"].append(file["database"]) + metadatum["database"] = metadata_unique_values["database"].index( + file["database"] + ) + metadatum["subset"] = Subsets.index(file["subset"]) + + # keep track of speaker label scope (file, database, or global) for speaker diarization protocols + if isinstance(self.protocol, SpeakerDiarizationProtocol): + metadatum["scope"] = Scopes.index(file["scope"]) + + # keep track of list of classes for regular segmentation protocols + # Different files may be annotated using a different set of classes + # (e.g. one database for speech/music/noise, and another one for male/female/child) + if isinstance(self.protocol, SegmentationProtocol): + if "classes" in file: + local_classes = file["classes"] + else: + local_classes = file["annotation"].labels() + + # if task was not initialized with a fixed list of classes, + # we build it as the union of all classes found in files + if self.classes is None: + for klass in local_classes: + if klass not in classes: + classes.append(klass) + annotated_classes.append( + [classes.index(klass) for klass in local_classes] + ) + + # if task was initialized with a fixed list of classes, + # we make sure that all files use a subset of these classes + # if they don't, we issue a warning and ignore the extra classes + else: + extra_classes = set(local_classes) - set(self.classes) + if extra_classes: + warnings.warn( + f"Ignoring extra classes ({', '.join(extra_classes)}) found for file {file['uri']} ({file['database']}). " + ) + annotated_classes.append( + [ + self.classes.index(klass) + for klass in set(local_classes) & set(self.classes) + ] + ) + + remaining_metadata_keys = set(file) - set( + [ + "uri", + "database", + "subset", + "audio", + "torchaudio.info", + "scope", + "classes", + "annotation", + "annotated", + ] + ) + + # keep track of any other (integer or string) metadata provided by the protocol + # (e.g. a "domain" key for domain-adversarial training) + for key in remaining_metadata_keys: + value = file[key] + + if isinstance(value, str): + if value not in metadata_unique_values[key]: + metadata_unique_values[key].append(value) + metadatum[key] = metadata_unique_values[key].index(value) + + elif isinstance(value, int): + metadatum[key] = value + + else: + warnings.warn( + f"Ignoring '{key}' metadata because of its type ({type(value)}). Only str and int are supported for now.", + category=UserWarning, + ) + + metadata.append(metadatum) + + database_unique_labels = list() + + # reset list of file-scoped labels + file_unique_labels = list() + + # path to audio file + audios.append(str(file["audio"])) + + # audio info + audio_info = file["torchaudio.info"] + audio_infos.append( + ( + audio_info.sample_rate, # sample rate + audio_info.num_frames, # number of frames + audio_info.num_channels, # number of channels + audio_info.bits_per_sample, # bits per sample + ) + ) + audio_encodings.append(audio_info.encoding) # encoding + + # annotated regions and duration + _annotated_duration = 0.0 + for segment in file["annotated"]: + # skip annotated regions that are shorter than training chunk duration + if segment.duration < duration: + continue + + # append annotated region + annotated_region = ( + file_id, + segment.duration, + segment.start, + segment.end, + ) + annotated_regions.append(annotated_region) + + # increment annotated duration + _annotated_duration += segment.duration + + # append annotated duration + annotated_duration.append(_annotated_duration) + + # annotations + for segment, _, label in file["annotation"].itertracks(yield_label=True): + # "scope" is provided by speaker diarization protocols to indicate + # whether speaker labels are local to the file ('file'), consistent across + # all files in a database ('database'), or globally consistent ('global') + + if "scope" in file: + # 0 = 'file' + # 1 = 'database' + # 2 = 'global' + scope = Scopes.index(file["scope"]) + + # update list of file-scope labels + if label not in file_unique_labels: + file_unique_labels.append(label) + # and convert label to its (file-scope) index + file_label_idx = file_unique_labels.index(label) + + database_label_idx = global_label_idx = -1 + + if scope > 0: # 'database' or 'global' + # update list of database-scope labels + if label not in database_unique_labels: + database_unique_labels.append(label) + + # and convert label to its (database-scope) index + database_label_idx = database_unique_labels.index(label) + + if scope > 1: # 'global' + # update list of global-scope labels + if label not in unique_labels: + unique_labels.append(label) + # and convert label to its (global-scope) index + global_label_idx = unique_labels.index(label) + + # basic segmentation protocols do not provide "scope" information + # as classes are global by definition + + else: + try: + file_label_idx = ( + database_label_idx + ) = global_label_idx = classes.index(label) + except ValueError: + # skip labels that are not in the list of classes + continue + + annotations.append( + ( + file_id, # index of file + segment.start, # start time + segment.end, # end time + file_label_idx, # file-scope label index + database_label_idx, # database-scope label index + global_label_idx, # global-scope index + ) + ) + + # since not all metadata keys are present in all files, fallback to -1 when a key is missing + metadata = [ + tuple(metadatum.get(key, -1) for key in metadata_unique_values) + for metadatum in metadata + ] + dtype = [(key, "i") for key in metadata_unique_values] + + # save all protocol data in a dict + data_dict = {} + data_dict["metadata"] = np.array(metadata, dtype=dtype) + data_dict["audios"] = np.array(audios, dtype=np.string_) + + # turn list of files metadata into a single numpy array + # TODO: improve using https://github.com/pytorch/pytorch/issues/13246#issuecomment-617140519 + dtype = [ + ("sample_rate", "i"), + ("num_frames", "i"), + ("num_channels", "i"), + ("bits_per_sample", "i"), + ] + data_dict["audio_infos"] = np.array(audio_infos, dtype=dtype) + data_dict["audio_encodings"] = np.array(audio_encodings, dtype=np.string_) + data_dict["annotated_duration"] = np.array(annotated_duration) + + # turn list of annotated regions into a single numpy array + dtype = [("file_id", "i"), ("duration", "f"), ("start", "f"), ("end", "f")] + annotated_regions_array = np.array(annotated_regions, dtype=dtype) + data_dict["annotated_regions"] = annotated_regions_array + + # convert annotated_classes (which is a list of list of classes, one list of classes per file) + # into a single (num_files x num_classes) numpy array: + # * True indicates that this particular class was annotated for this particular file (though it may not be active in this file) + # * False indicates that this particular class was not even annotated (i.e. its absence does not imply that it is not active in this file) + if isinstance(self.protocol, SegmentationProtocol) and self.classes is None: + data_dict["classes"] = classes + annotated_classes_array = np.zeros( + (len(annotated_classes), len(classes)), dtype=np.bool_ + ) + for file_id, classes in enumerate(annotated_classes): + annotated_classes_array[file_id, classes] = True + data_dict["annotated_classes"] = annotated_classes_array + + # turn list of annotations into a single numpy array + dtype = [ + ("file_id", "i"), + ("start", "f"), + ("end", "f"), + ("file_label_idx", "i"), + ("database_label_idx", "i"), + ("global_label_idx", "i"), + ] + + data_dict["annotations"] = np.array(annotations, dtype=dtype) + data_dict["metadata_unique_values"] = metadata_unique_values + + if not self.has_validation: + return + + validation_chunks = list() + + # obtain indexes of files in the validation subset + validation_file_ids = np.where( + data_dict["metadata"]["subset"] == Subsets.index("development") + )[0] + + # iterate over files in the validation subset + for file_id in validation_file_ids: + # get annotated regions in file + annotated_regions = annotated_regions_array[ + annotated_regions_array["file_id"] == file_id + ] + + # iterate over annotated regions + for annotated_region in annotated_regions: + # number of chunks in annotated region + num_chunks = round(annotated_region["duration"] // duration) + + # iterate over chunks + for c in range(num_chunks): + start_time = annotated_region["start"] + c * duration + validation_chunks.append((file_id, start_time, duration)) + + dtype = [("file_id", "i"), ("start", "f"), ("duration", "f")] + data_dict["validation_chunks"] = np.array(validation_chunks, dtype=dtype) + + # cache generated protocol data on disk + with open(self.cache_path, 'wb') as data_file: + pickle.dump(data_dict, data_file) @property def specifications(self) -> Union[Specifications, Tuple[Specifications]]: @@ -289,21 +622,6 @@ def has_setup_metadata(self): def has_setup_metadata(self, value: bool): self._has_setup_metadata = value - def setup_metadata(self): - """Called at the beginning of training at the very beginning of Model.setup(stage="fit") - - Notes - ----- - This hook is called on every process when using DDP. - - If `specifications` attribute has not been set in `__init__`, - `setup` is your last chance to set it. - """ - - if not self.has_setup_metadata: - self.setup() - self.has_setup_metadata = True - def setup_loss_func(self): pass diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index 018e8db70..2ccc6a8be 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -30,6 +30,7 @@ import matplotlib.pyplot as plt import numpy as np import torch +import pickle from pyannote.database.protocol import SegmentationProtocol, SpeakerDiarizationProtocol from pyannote.database.protocol.protocol import Scope, Subset from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger @@ -73,317 +74,30 @@ def get_file(self, file_id): def setup(self): """Setup""" - - # duration of training chunks - # TODO: handle variable duration case - duration = getattr(self, "duration", 0.0) - - # list of possible values for each metadata key - metadata_unique_values = defaultdict(list) - - metadata_unique_values["subset"] = Subsets - - if isinstance(self.protocol, SpeakerDiarizationProtocol): - metadata_unique_values["scope"] = Scopes - - elif isinstance(self.protocol, SegmentationProtocol): - classes = getattr(self, "classes", list()) - - # make sure classes attribute exists (and set to None if it did not exist) - self.classes = getattr(self, "classes", None) - if self.classes is None: - classes = list() - # metadata_unique_values["classes"] = list(classes) - - audios = list() # list of path to audio files - audio_infos = list() - audio_encodings = list() - metadata = list() # list of metadata - - annotated_duration = list() # total duration of annotated regions (per file) - annotated_regions = list() # annotated regions - annotations = list() # actual annotations - annotated_classes = list() # list of annotated classes (per file) - unique_labels = list() - - if self.has_validation: - files_iter = itertools.chain( - self.protocol.train(), self.protocol.development() - ) - else: - files_iter = self.protocol.train() - - for file_id, file in enumerate(files_iter): - # gather metadata and update metadata_unique_values so that each metadatum - # (e.g. source database or label) is represented by an integer. - metadatum = dict() - - # keep track of source database and subset (train, development, or test) - if file["database"] not in metadata_unique_values["database"]: - metadata_unique_values["database"].append(file["database"]) - metadatum["database"] = metadata_unique_values["database"].index( - file["database"] - ) - metadatum["subset"] = Subsets.index(file["subset"]) - - # keep track of speaker label scope (file, database, or global) for speaker diarization protocols - if isinstance(self.protocol, SpeakerDiarizationProtocol): - metadatum["scope"] = Scopes.index(file["scope"]) - - # keep track of list of classes for regular segmentation protocols - # Different files may be annotated using a different set of classes - # (e.g. one database for speech/music/noise, and another one for male/female/child) - if isinstance(self.protocol, SegmentationProtocol): - if "classes" in file: - local_classes = file["classes"] - else: - local_classes = file["annotation"].labels() - - # if task was not initialized with a fixed list of classes, - # we build it as the union of all classes found in files - if self.classes is None: - for klass in local_classes: - if klass not in classes: - classes.append(klass) - annotated_classes.append( - [classes.index(klass) for klass in local_classes] - ) - - # if task was initialized with a fixed list of classes, - # we make sure that all files use a subset of these classes - # if they don't, we issue a warning and ignore the extra classes - else: - extra_classes = set(local_classes) - set(self.classes) - if extra_classes: - warnings.warn( - f"Ignoring extra classes ({', '.join(extra_classes)}) found for file {file['uri']} ({file['database']}). " - ) - annotated_classes.append( - [ - self.classes.index(klass) - for klass in set(local_classes) & set(self.classes) - ] - ) - - remaining_metadata_keys = set(file) - set( - [ - "uri", - "database", - "subset", - "audio", - "torchaudio.info", - "scope", - "classes", - "annotation", - "annotated", - ] - ) - - # keep track of any other (integer or string) metadata provided by the protocol - # (e.g. a "domain" key for domain-adversarial training) - for key in remaining_metadata_keys: - value = file[key] - - if isinstance(value, str): - if value not in metadata_unique_values[key]: - metadata_unique_values[key].append(value) - metadatum[key] = metadata_unique_values[key].index(value) - - elif isinstance(value, int): - metadatum[key] = value - - else: - warnings.warn( - f"Ignoring '{key}' metadata because of its type ({type(value)}). Only str and int are supported for now.", - category=UserWarning, - ) - - metadata.append(metadatum) - - database_unique_labels = list() - - # reset list of file-scoped labels - file_unique_labels = list() - - # path to audio file - audios.append(str(file["audio"])) - - # audio info - audio_info = file["torchaudio.info"] - audio_infos.append( - ( - audio_info.sample_rate, # sample rate - audio_info.num_frames, # number of frames - audio_info.num_channels, # number of channels - audio_info.bits_per_sample, # bits per sample - ) - ) - audio_encodings.append(audio_info.encoding) # encoding - - # annotated regions and duration - _annotated_duration = 0.0 - for segment in file["annotated"]: - # skip annotated regions that are shorter than training chunk duration - if segment.duration < duration: - continue - - # append annotated region - annotated_region = ( - file_id, - segment.duration, - segment.start, - segment.end, - ) - annotated_regions.append(annotated_region) - - # increment annotated duration - _annotated_duration += segment.duration - - # append annotated duration - annotated_duration.append(_annotated_duration) - - # annotations - for segment, _, label in file["annotation"].itertracks(yield_label=True): - # "scope" is provided by speaker diarization protocols to indicate - # whether speaker labels are local to the file ('file'), consistent across - # all files in a database ('database'), or globally consistent ('global') - - if "scope" in file: - # 0 = 'file' - # 1 = 'database' - # 2 = 'global' - scope = Scopes.index(file["scope"]) - - # update list of file-scope labels - if label not in file_unique_labels: - file_unique_labels.append(label) - # and convert label to its (file-scope) index - file_label_idx = file_unique_labels.index(label) - - database_label_idx = global_label_idx = -1 - - if scope > 0: # 'database' or 'global' - # update list of database-scope labels - if label not in database_unique_labels: - database_unique_labels.append(label) - - # and convert label to its (database-scope) index - database_label_idx = database_unique_labels.index(label) - - if scope > 1: # 'global' - # update list of global-scope labels - if label not in unique_labels: - unique_labels.append(label) - # and convert label to its (global-scope) index - global_label_idx = unique_labels.index(label) - - # basic segmentation protocols do not provide "scope" information - # as classes are global by definition - - else: - try: - file_label_idx = ( - database_label_idx - ) = global_label_idx = classes.index(label) - except ValueError: - # skip labels that are not in the list of classes - continue - - annotations.append( - ( - file_id, # index of file - segment.start, # start time - segment.end, # end time - file_label_idx, # file-scope label index - database_label_idx, # database-scope label index - global_label_idx, # global-scope index - ) - ) - - # since not all metadata keys are present in all files, fallback to -1 when a key is missing - metadata = [ - tuple(metadatum.get(key, -1) for key in metadata_unique_values) - for metadatum in metadata - ] - dtype = [(key, "i") for key in metadata_unique_values] - self.metadata = np.array(metadata, dtype=dtype) - - # NOTE: read with str(self.audios[file_id], encoding='utf-8') - self.audios = np.array(audios, dtype=np.string_) - - # turn list of files metadata into a single numpy array - # TODO: improve using https://github.com/pytorch/pytorch/issues/13246#issuecomment-617140519 - - dtype = [ - ("sample_rate", "i"), - ("num_frames", "i"), - ("num_channels", "i"), - ("bits_per_sample", "i"), - ] - self.audio_infos = np.array(audio_infos, dtype=dtype) - self.audio_encodings = np.array(audio_encodings, dtype=np.string_) - - self.annotated_duration = np.array(annotated_duration) - - # turn list of annotated regions into a single numpy array - dtype = [("file_id", "i"), ("duration", "f"), ("start", "f"), ("end", "f")] - self.annotated_regions = np.array(annotated_regions, dtype=dtype) - - # convert annotated_classes (which is a list of list of classes, one list of classes per file) - # into a single (num_files x num_classes) numpy array: - # * True indicates that this particular class was annotated for this particular file (though it may not be active in this file) - # * False indicates that this particular class was not even annotated (i.e. its absence does not imply that it is not active in this file) - if isinstance(self.protocol, SegmentationProtocol) and self.classes is None: - self.classes = classes - self.annotated_classes = np.zeros( - (len(annotated_classes), len(self.classes)), dtype=np.bool_ - ) - for file_id, classes in enumerate(annotated_classes): - self.annotated_classes[file_id, classes] = True - - # turn list of annotations into a single numpy array - dtype = [ - ("file_id", "i"), - ("start", "f"), - ("end", "f"), - ("file_label_idx", "i"), - ("database_label_idx", "i"), - ("global_label_idx", "i"), - ] - self.annotations = np.array(annotations, dtype=dtype) - - self.metadata_unique_values = metadata_unique_values - - if not self.has_validation: - return - - validation_chunks = list() - - # obtain indexes of files in the validation subset - validation_file_ids = np.where( - self.metadata["subset"] == Subsets.index("development") - )[0] - - # iterate over files in the validation subset - for file_id in validation_file_ids: - # get annotated regions in file - annotated_regions = self.annotated_regions[ - self.annotated_regions["file_id"] == file_id - ] - - # iterate over annotated regions - for annotated_region in annotated_regions: - # number of chunks in annotated region - num_chunks = round(annotated_region["duration"] // duration) - - # iterate over chunks - for c in range(num_chunks): - start_time = annotated_region["start"] + c * duration - validation_chunks.append((file_id, start_time, duration)) - - dtype = [("file_id", "i"), ("start", "f"), ("duration", "f")] - self.validation_chunks = np.array(validation_chunks, dtype=dtype) - + if not self.has_setup_metadata: + # load data cached by prepare_data method into the task: + try: + with open(self.cache_path, 'rb') as data_file: + data_dict = pickle.load(data_file) + self.metadata = data_dict["metadata"] + self.audios = data_dict["audios"] + self.audio_infos= data_dict["audio_infos"] + self.audio_encodings = data_dict["audio_encodings"] + self.annotated_duration = data_dict["annotated_duration"] + self.annotated_regions = data_dict["annotated_regions"] + self.annotated_classes = data_dict["annotated_classes"] + self.annotations = data_dict["annotations"] + self.metadata_unique_values = data_dict["metadata_unique_values"] + if isinstance(self.protocol, SegmentationProtocol): + self.classes = data_dict["classes"] + if self.has_validation: + self.validation_chunks = data_dict["validation_chunks"] + except FileNotFoundError: + print("Cached data for protocol not found. Ensure that prepare_data was \ + executed correctly and that the path to the task cache is correct") + raise + self.has_setup_metadata = True + def default_metric( self, ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: From 882957404376672a5b4cafe8bc81c48e7dbaaa1a Mon Sep 17 00:00:00 2001 From: clement-pages Date: Thu, 2 Nov 2023 14:33:33 +0100 Subject: [PATCH 44/83] modify organisation of `pyannote` segmentation tasks Now all the segmentations tasks in `pyannote` inherit the `SegmentationTask` (previously `SegmentationTaskMixin`), which inherits the `Task` class. This commit also adds a `prepared_data` attribute to the `Task` class. That attribute is a dict which contains all the prepared data by the `prepare_data` method. --- pyannote/audio/core/task.py | 63 ++++++++++++------- pyannote/audio/tasks/segmentation/mixins.py | 62 +++++------------- .../audio/tasks/segmentation/multilabel.py | 16 ++--- .../overlapped_speech_detection.py | 10 +-- .../tasks/segmentation/speaker_diarization.py | 22 +++---- .../segmentation/voice_activity_detection.py | 10 +-- 6 files changed, 84 insertions(+), 99 deletions(-) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index fdabc3ba5..c65a98a30 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -54,6 +54,7 @@ Subsets = list(Subset.__args__) Scopes = list(Scope.__args__) + # Type of machine learning problem class Problem(Enum): BINARY_CLASSIFICATION = 0 @@ -225,7 +226,7 @@ def __init__( pin_memory: bool = False, augmentation: BaseWaveformTransform = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, - cache_path = "./cache/task_cache/protocle_data.pickle" + cache_path="./cache/task_cache/prepared_data.pickle" ): super().__init__() @@ -266,6 +267,7 @@ def __init__( self.augmentation = augmentation or Identity(output_type="dict") self._metric = metric self.cache_path = cache_path + self.prepared_data = {} def prepare_data(self): """Use this to download and prepare data @@ -284,8 +286,8 @@ def prepare_data(self): # data was already created, do nothing return else: - #create the repo - os.makedirs(self.cache_path[:self.cache_path.rfind('/')]) + # create the repo + os.makedirs(self.cache_path[:self.cache_path.rfind('/')], exist_ok=True) # duration of training chunks # TODO: handle variable duration case duration = getattr(self, "duration", 0.0) @@ -518,11 +520,11 @@ def prepare_data(self): for metadatum in metadata ] dtype = [(key, "i") for key in metadata_unique_values] - + # save all protocol data in a dict - data_dict = {} - data_dict["metadata"] = np.array(metadata, dtype=dtype) - data_dict["audios"] = np.array(audios, dtype=np.string_) + prepared_data = {} + prepared_data["metadata"] = np.array(metadata, dtype=dtype) + prepared_data["audios"] = np.array(audios, dtype=np.string_) # turn list of files metadata into a single numpy array # TODO: improve using https://github.com/pytorch/pytorch/issues/13246#issuecomment-617140519 @@ -532,27 +534,27 @@ def prepare_data(self): ("num_channels", "i"), ("bits_per_sample", "i"), ] - data_dict["audio_infos"] = np.array(audio_infos, dtype=dtype) - data_dict["audio_encodings"] = np.array(audio_encodings, dtype=np.string_) - data_dict["annotated_duration"] = np.array(annotated_duration) - + prepared_data["audio_infos"] = np.array(audio_infos, dtype=dtype) + prepared_data["audio_encodings"] = np.array(audio_encodings, dtype=np.string_) + prepared_data["annotated_duration"] = np.array(annotated_duration) + # turn list of annotated regions into a single numpy array dtype = [("file_id", "i"), ("duration", "f"), ("start", "f"), ("end", "f")] annotated_regions_array = np.array(annotated_regions, dtype=dtype) - data_dict["annotated_regions"] = annotated_regions_array - + prepared_data["annotated_regions"] = annotated_regions_array + # convert annotated_classes (which is a list of list of classes, one list of classes per file) # into a single (num_files x num_classes) numpy array: # * True indicates that this particular class was annotated for this particular file (though it may not be active in this file) # * False indicates that this particular class was not even annotated (i.e. its absence does not imply that it is not active in this file) if isinstance(self.protocol, SegmentationProtocol) and self.classes is None: - data_dict["classes"] = classes + prepared_data["classes"] = classes annotated_classes_array = np.zeros( (len(annotated_classes), len(classes)), dtype=np.bool_ ) for file_id, classes in enumerate(annotated_classes): annotated_classes_array[file_id, classes] = True - data_dict["annotated_classes"] = annotated_classes_array + prepared_data["annotated_classes"] = annotated_classes_array # turn list of annotations into a single numpy array dtype = [ @@ -564,17 +566,17 @@ def prepare_data(self): ("global_label_idx", "i"), ] - data_dict["annotations"] = np.array(annotations, dtype=dtype) - data_dict["metadata_unique_values"] = metadata_unique_values - + prepared_data["annotations"] = np.array(annotations, dtype=dtype) + prepared_data["metadata_unique_values"] = metadata_unique_values + if not self.has_validation: return validation_chunks = list() - + # obtain indexes of files in the validation subset validation_file_ids = np.where( - data_dict["metadata"]["subset"] == Subsets.index("development") + prepared_data["metadata"]["subset"] == Subsets.index("development") )[0] # iterate over files in the validation subset @@ -595,17 +597,30 @@ def prepare_data(self): validation_chunks.append((file_id, start_time, duration)) dtype = [("file_id", "i"), ("start", "f"), ("duration", "f")] - data_dict["validation_chunks"] = np.array(validation_chunks, dtype=dtype) - + prepared_data["validation_chunks"] = np.array(validation_chunks, dtype=dtype) + # cache generated protocol data on disk with open(self.cache_path, 'wb') as data_file: - pickle.dump(data_dict, data_file) + pickle.dump(prepared_data, data_file) + + def setup(self): + """Setup""" + if not self.has_setup_metadata: + # load data cached by prepare_data method into the task: + try: + with open(self.cache_path, 'rb') as data_file: + self.prepared_data = pickle.load(data_file) + except FileNotFoundError: + print("Cached data for protocol not found. Ensure that prepare_data was \ + executed correctly and that the path to the task cache is correct") + raise + self.has_setup_metadata = True @property def specifications(self) -> Union[Specifications, Tuple[Specifications]]: # setup metadata on-demand the first time specifications are requested and missing if not hasattr(self, "_specifications"): - self.setup_metadata() + self.setup() return self._specifications @specifications.setter diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index 2ccc6a8be..e1c96c18e 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -23,15 +23,11 @@ import itertools import math import random -import warnings -from collections import defaultdict from typing import Dict, Sequence, Union import matplotlib.pyplot as plt import numpy as np import torch -import pickle -from pyannote.database.protocol import SegmentationProtocol, SpeakerDiarizationProtocol from pyannote.database.protocol.protocol import Scope, Subset from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger from torch.utils.data._utils.collate import default_collate @@ -39,23 +35,23 @@ from torchmetrics import Metric from torchmetrics.classification import BinaryAUROC, MulticlassAUROC, MultilabelAUROC -from pyannote.audio.core.task import Problem +from pyannote.audio.core.task import Problem, Task from pyannote.audio.utils.random import create_rng_for_worker Subsets = list(Subset.__args__) Scopes = list(Scope.__args__) -class SegmentationTaskMixin: +class SegmentationTask(Task): """Methods common to most segmentation tasks""" def get_file(self, file_id): file = dict() - file["audio"] = str(self.audios[file_id], encoding="utf-8") + file["audio"] = str(self.prepared_data["audios"][file_id], encoding="utf-8") - _audio_info = self.audio_infos[file_id] - _encoding = self.audio_encodings[file_id] + _audio_info = self.prepared_data["audio_infos"][file_id] + _encoding = self.prepared_data["audio_encodings"][file_id] sample_rate = _audio_info["sample_rate"] num_frames = _audio_info["num_frames"] @@ -72,32 +68,6 @@ def get_file(self, file_id): return file - def setup(self): - """Setup""" - if not self.has_setup_metadata: - # load data cached by prepare_data method into the task: - try: - with open(self.cache_path, 'rb') as data_file: - data_dict = pickle.load(data_file) - self.metadata = data_dict["metadata"] - self.audios = data_dict["audios"] - self.audio_infos= data_dict["audio_infos"] - self.audio_encodings = data_dict["audio_encodings"] - self.annotated_duration = data_dict["annotated_duration"] - self.annotated_regions = data_dict["annotated_regions"] - self.annotated_classes = data_dict["annotated_classes"] - self.annotations = data_dict["annotations"] - self.metadata_unique_values = data_dict["metadata_unique_values"] - if isinstance(self.protocol, SegmentationProtocol): - self.classes = data_dict["classes"] - if self.has_validation: - self.validation_chunks = data_dict["validation_chunks"] - except FileNotFoundError: - print("Cached data for protocol not found. Ensure that prepare_data was \ - executed correctly and that the path to the task cache is correct") - raise - self.has_setup_metadata = True - def default_metric( self, ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: @@ -133,13 +103,13 @@ def train__iter__helper(self, rng: random.Random, **filters): """ # indices of training files that matches domain filters - training = self.metadata["subset"] == Subsets.index("train") + training = self.prepared_data["metadata"]["subset"] == Subsets.index("train") for key, value in filters.items(): - training &= self.metadata[key] == self.metadata_unique_values[key].index(value) + training &= self.prepared_data["metadata"][key] == self.prepared_data["metadata_unique_values"][key].index(value) file_ids = np.where(training)[0] # turn annotated duration into a probability distribution - annotated_duration = self.annotated_duration[file_ids] + annotated_duration = self.prepared_data["annotated_duration"][file_ids] prob_annotated_duration = annotated_duration / np.sum(annotated_duration) duration = self.duration @@ -154,13 +124,13 @@ def train__iter__helper(self, rng: random.Random, **filters): for _ in range(num_chunks_per_file): # find indices of annotated regions in this file annotated_region_indices = np.where( - self.annotated_regions["file_id"] == file_id + self.prepared_data["annotated_regions"]["file_id"] == file_id )[0] # turn annotated regions duration into a probability distribution - prob_annotated_regions_duration = self.annotated_regions["duration"][ + prob_annotated_regions_duration = self.prepared_data["annotated_regions"]["duration"][ annotated_region_indices - ] / np.sum(self.annotated_regions["duration"][annotated_region_indices]) + ] / np.sum(self.prepared_data["annotated_regions"]["duration"][annotated_region_indices]) # selected one annotated region at random (with probability proportional to its duration) annotated_region_index = np.random.choice( @@ -168,7 +138,7 @@ def train__iter__helper(self, rng: random.Random, **filters): ) # select one chunk at random in this annotated region - _, _, start, end = self.annotated_regions[annotated_region_index] + _, _, start, end = self.prepared_data["annotated_regions"][annotated_region_index] start_time = rng.uniform(start, end - duration) yield self.prepare_chunk(file_id, start_time, duration) @@ -199,7 +169,7 @@ def train__iter__(self): # create a subchunk generator for each combination of "balance" keys subchunks = dict() for product in itertools.product( - *[self.metadata_unique_values[key] for key in balance] + *[self.prepared_data["metadata_unique_values"][key] for key in balance] ): # we iterate on the cartesian product of the values in metadata_unique_values # eg: for balance=["database", "split"], with 2 databases and 2 splits: @@ -271,11 +241,11 @@ def collate_fn(self, batch, stage="train"): def train__len__(self): # Number of training samples in one epoch - duration = np.sum(self.annotated_duration) + duration = np.sum(self.prepared_data["annotated_duration"]) return max(self.batch_size, math.ceil(duration / self.duration)) def val__getitem__(self, idx): - validation_chunk = self.validation_chunks[idx] + validation_chunk = self.prepared_data["validation_chunks"][idx] return self.prepare_chunk( validation_chunk["file_id"], validation_chunk["start"], @@ -283,7 +253,7 @@ def val__getitem__(self, idx): ) def val__len__(self): - return len(self.validation_chunks) + return len(self.prepared_data["validation_chunks"]) def validation_step(self, batch, batch_idx: int): """Compute validation area under the ROC curve diff --git a/pyannote/audio/tasks/segmentation/multilabel.py b/pyannote/audio/tasks/segmentation/multilabel.py index c1d58431a..560b903f1 100644 --- a/pyannote/audio/tasks/segmentation/multilabel.py +++ b/pyannote/audio/tasks/segmentation/multilabel.py @@ -31,11 +31,11 @@ from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric -from pyannote.audio.core.task import Problem, Resolution, Specifications, Task -from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin +from pyannote.audio.core.task import Problem, Resolution, Specifications +from pyannote.audio.tasks.segmentation.mixins import SegmentationTask -class MultiLabelSegmentation(SegmentationTaskMixin, Task): +class MultiLabelSegmentation(SegmentationTask): """Generic multi-label segmentation Multi-label segmentation is the process of detecting temporal intervals @@ -123,7 +123,7 @@ def setup(self): super().setup() self.specifications = Specifications( - classes=self.classes, + classes=self.prepared_data["classes"], problem=Problem.MULTI_LABEL_CLASSIFICATION, resolution=Resolution.FRAME, duration=self.duration, @@ -169,7 +169,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample = dict() sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) # gather all annotations of current file - annotations = self.annotations[self.annotations["file_id"] == file_id] + annotations = self.prepared_data["annotations"][self.prepared_data["annotations"]["file_id"] == file_id] # gather all annotations with non-empty intersection with current chunk chunk_annotations = annotations[ @@ -184,9 +184,9 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # frame-level targets (-1 for un-annotated classes) y = -np.ones( - (self.model.example_output.num_frames, len(self.classes)), dtype=np.int8 + (self.model.example_output.num_frames, len(self.prepared_data["classes"])), dtype=np.int8 ) - y[:, self.annotated_classes[file_id]] = 0 + y[:, self.prepared_data["annotated_classes"][file_id]] = 0 for start, end, label in zip( start_idx, end_idx, chunk_annotations["global_label_idx"] ): @@ -196,7 +196,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): y, self.model.example_output.frames, labels=self.classes ) - metadata = self.metadata[file_id] + metadata = self.prepared_data["metadata"][file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} sample["meta"]["file"] = file_id diff --git a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py index 0b7209c5c..78c0d3c9c 100644 --- a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py +++ b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py @@ -29,11 +29,11 @@ from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric -from pyannote.audio.core.task import Problem, Resolution, Specifications, Task -from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin +from pyannote.audio.core.task import Problem, Resolution, Specifications +from pyannote.audio.tasks.segmentation.mixins import SegmentationTask -class OverlappedSpeechDetection(SegmentationTaskMixin, Task): +class OverlappedSpeechDetection(SegmentationTask): """Overlapped speech detection Overlapped speech detection is the task of detecting regions where at least @@ -163,7 +163,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) # gather all annotations of current file - annotations = self.annotations[self.annotations["file_id"] == file_id] + annotations = self.prepared_data["annotations"][self.prepared_data["annotations"]["file_id"] == file_id] # gather all annotations with non-empty intersection with current chunk chunk_annotations = annotations[ @@ -186,7 +186,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): y, self.model.example_output.frames, labels=["speech"] ) - metadata = self.metadata[file_id] + metadata = self.prepared_data["metadata"][file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} sample["meta"]["file"] = file_id diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index 1094672ed..d42e53a77 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -37,8 +37,8 @@ from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric -from pyannote.audio.core.task import Problem, Resolution, Specifications, Task -from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin +from pyannote.audio.core.task import Problem, Resolution, Specifications +from pyannote.audio.tasks.segmentation.mixins import SegmentationTask from pyannote.audio.torchmetrics import ( DiarizationErrorRate, FalseAlarmRate, @@ -58,7 +58,7 @@ Scopes = list(Scope.__args__) -class SpeakerDiarization(SegmentationTaskMixin, Task): +class SpeakerDiarization(SegmentationTask): """Speaker diarization Parameters @@ -191,18 +191,18 @@ def setup(self): # estimate maximum number of speakers per chunk when not provided if self.max_speakers_per_chunk is None: - training = self.metadata["subset"] == Subsets.index("train") + training = self.prepared_data["metadata"]["subset"] == Subsets.index("train") num_unique_speakers = [] progress_description = f"Estimating maximum number of speakers per {self.duration:g}s chunk in the training set" for file_id in track( np.where(training)[0], description=progress_description ): - annotations = self.annotations[ - np.where(self.annotations["file_id"] == file_id)[0] + annotations = self.prepared_data["annotations"][ + np.where(self.prepared_data["annotations"]["file_id"] == file_id)[0] ] - annotated_regions = self.annotated_regions[ - np.where(self.annotated_regions["file_id"] == file_id)[0] + annotated_regions = self.prepared_data["annotated_regions"][ + np.where(self.prepared_data["annotated_regions"]["file_id"] == file_id)[0] ] for region in annotated_regions: # find annotations within current region @@ -318,7 +318,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): file = self.get_file(file_id) # get label scope - label_scope = Scopes[self.metadata[file_id]["scope"]] + label_scope = Scopes[self.prepared_data["metadata"][file_id]["scope"]] label_scope_key = f"{label_scope}_label_idx" # @@ -328,7 +328,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) # gather all annotations of current file - annotations = self.annotations[self.annotations["file_id"] == file_id] + annotations = self.prepared_data["annotations"][self.prepared_data["annotations"]["file_id"] == file_id] # gather all annotations with non-empty intersection with current chunk chunk_annotations = annotations[ @@ -364,7 +364,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): y, self.model.example_output.frames, labels=labels ) - metadata = self.metadata[file_id] + metadata = self.prepared_data["metadata"][file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} sample["meta"]["file"] = file_id diff --git a/pyannote/audio/tasks/segmentation/voice_activity_detection.py b/pyannote/audio/tasks/segmentation/voice_activity_detection.py index fd9eb8e75..259aa3866 100644 --- a/pyannote/audio/tasks/segmentation/voice_activity_detection.py +++ b/pyannote/audio/tasks/segmentation/voice_activity_detection.py @@ -28,11 +28,11 @@ from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric -from pyannote.audio.core.task import Problem, Resolution, Specifications, Task -from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin +from pyannote.audio.core.task import Problem, Resolution, Specifications +from pyannote.audio.tasks.segmentation.mixins import SegmentationTask -class VoiceActivityDetection(SegmentationTaskMixin, Task): +class VoiceActivityDetection(SegmentationTask): """Voice activity detection Voice activity detection (or VAD) is the task of detecting speech regions @@ -145,7 +145,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) # gather all annotations of current file - annotations = self.annotations[self.annotations["file_id"] == file_id] + annotations = self.prepared_data["annotations"][self.prepared_data["annotations"]["file_id"] == file_id] # gather all annotations with non-empty intersection with current chunk chunk_annotations = annotations[ @@ -167,7 +167,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): y, self.model.example_output.frames, labels=["speech"] ) - metadata = self.metadata[file_id] + metadata = self.prepared_data["metadata"][file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} sample["meta"]["file"] = file_id From be6f7ec15f11440cd12a6b4505b91d35fa2077fc Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 7 Nov 2023 08:39:23 +0100 Subject: [PATCH 45/83] add two training tests One for the test of the `MultiLabelSegmentation` task, and the other for the test of the `SupervisedRepresentationLearningWithArcFace` task. --- tests/test_train.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_train.py b/tests/test_train.py index 7a7bfe338..65fcc901e 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -3,10 +3,13 @@ from pytorch_lightning import Trainer from pyannote.audio.models.segmentation.debug import SimpleSegmentationModel +from pyannote.audio.models.embedding.debug import SimpleEmbeddingModel from pyannote.audio.tasks import ( OverlappedSpeechDetection, SpeakerDiarization, VoiceActivityDetection, + MultiLabelSegmentation, + SupervisedRepresentationLearningWithArcFace, ) @@ -24,6 +27,20 @@ def test_train_segmentation(protocol): trainer.fit(model) +def test_train_multilabel_segmentation(protocol): + multilabel_segmentation = MultiLabelSegmentation(protocol) + model = SimpleSegmentationModel(task=multilabel_segmentation) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") + trainer.fit(model) + + +def test_train_supervised_representation_with_arcface(protocol): + supervised_representation_with_arface = SupervisedRepresentationLearningWithArcFace(protocol) + model = SimpleEmbeddingModel(task=supervised_representation_with_arface) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") + trainer.fit(model) + + def test_train_voice_activity_detection(protocol): voice_activity_detection = VoiceActivityDetection(protocol) model = SimpleSegmentationModel(task=voice_activity_detection) From f447bb6a8802001f4b7da01d0493b9717667e2aa Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 7 Nov 2023 11:35:08 +0100 Subject: [PATCH 46/83] assign data directly to task in main process, in `prepare_data` This eliminates the need to reload pickle data in setup when in the main process --- pyannote/audio/core/task.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index c65a98a30..6935c9d8e 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -279,7 +279,7 @@ def prepare_data(self): Notes ----- - Called only once. + Called only once on the main process (and only on it). """ if os.path.exists(self.cache_path): @@ -598,8 +598,11 @@ def prepare_data(self): dtype = [("file_id", "i"), ("start", "f"), ("duration", "f")] prepared_data["validation_chunks"] = np.array(validation_chunks, dtype=dtype) + + self.prepared_data = prepared_data + self.has_setup_metadata = True - # cache generated protocol data on disk + # save preparated data on the disk with open(self.cache_path, 'wb') as data_file: pickle.dump(prepared_data, data_file) From 05ccc30ce17de882ca4584b6644d70357fe95c43 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Wed, 8 Nov 2023 11:45:31 +0100 Subject: [PATCH 47/83] handle call to `Task.prepare_data` and `Task.setup` under different scenarios --- pyannote/audio/core/task.py | 54 +++++++++++++------ .../audio/tasks/segmentation/multilabel.py | 4 ++ .../overlapped_speech_detection.py | 4 ++ .../tasks/segmentation/speaker_diarization.py | 4 ++ .../segmentation/voice_activity_detection.py | 4 ++ 5 files changed, 53 insertions(+), 17 deletions(-) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 6935c9d8e..519876511 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -26,7 +26,7 @@ import itertools import multiprocessing import numpy as np -import os +from pathlib import Path import pickle import sys import warnings @@ -54,7 +54,6 @@ Subsets = list(Subset.__args__) Scopes = list(Scope.__args__) - # Type of machine learning problem class Problem(Enum): BINARY_CLASSIFICATION = 0 @@ -226,7 +225,7 @@ def __init__( pin_memory: bool = False, augmentation: BaseWaveformTransform = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, - cache_path="./cache/task_cache/prepared_data.pickle" + cache_path=None ): super().__init__() @@ -281,13 +280,15 @@ def prepare_data(self): ----- Called only once on the main process (and only on it). """ - - if os.path.exists(self.cache_path): - # data was already created, do nothing - return - else: - # create the repo - os.makedirs(self.cache_path[:self.cache_path.rfind('/')], exist_ok=True) + if self.cache_path is not None: + cache_path = Path(self.cache_path) + if cache_path.exists(): + # data was already created, do nothing + return + # create a new cache directory at the path specified by the user + else: + cache_rep = Path(self.cache_path[: self.cache_path.rfind('/')]) + cache_rep.mkdir(parents=True, exist_ok=True) # duration of training chunks # TODO: handle variable duration case duration = getattr(self, "duration", 0.0) @@ -603,19 +604,30 @@ def prepare_data(self): self.has_setup_metadata = True # save preparated data on the disk - with open(self.cache_path, 'wb') as data_file: - pickle.dump(prepared_data, data_file) + if self.cache_path is not None: + with open(self.cache_path, 'wb') as cache_file: + pickle.dump(prepared_data, cache_file) + + self.has_prepared_data = True - def setup(self): + def setup(self, stage=None): """Setup""" if not self.has_setup_metadata: + # if no cache directory was provided by the user and task data was not already prepared + if self.cache_path is None and not self.has_prepared_data: + warnings.warn("""No path to the directory containing the cache of prepared data + has been specified. Data preparation will therefore be carried out + on each process used for training. To speed up data preparation, you + can specify a cache directory when instantiating the task.""", stacklevel=1) + self.prepare_data() + return # load data cached by prepare_data method into the task: try: - with open(self.cache_path, 'rb') as data_file: - self.prepared_data = pickle.load(data_file) + with open(self.cache_path, 'rb') as cache_file: + self.prepared_data = pickle.load(cache_file) except FileNotFoundError: - print("Cached data for protocol not found. Ensure that prepare_data was \ - executed correctly and that the path to the task cache is correct") + print("""Cached data for protocol not found. Ensure that prepare_data was + executed correctly and that the path to the task cache is correct""") raise self.has_setup_metadata = True @@ -632,6 +644,14 @@ def specifications( ): self._specifications = specifications + @property + def has_prepared_data(self): + return getattr(self, "_has_prepared_data", False) + + @has_prepared_data.setter + def has_prepared_data(self, value: bool): + self._has_setup_metadata = value + @property def has_setup_metadata(self): return getattr(self, "_has_setup_metadata", False) diff --git a/pyannote/audio/tasks/segmentation/multilabel.py b/pyannote/audio/tasks/segmentation/multilabel.py index 560b903f1..4fbce9cb2 100644 --- a/pyannote/audio/tasks/segmentation/multilabel.py +++ b/pyannote/audio/tasks/segmentation/multilabel.py @@ -79,6 +79,8 @@ class MultiLabelSegmentation(SegmentationTask): metric : optional Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. Defaults to AUROC (area under the ROC curve). + cache_path : str, optional + path to directory where to write or load task caches """ def __init__( @@ -94,6 +96,7 @@ def __init__( pin_memory: bool = False, augmentation: BaseWaveformTransform = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + cache_path=None, ): if not isinstance(protocol, SegmentationProtocol): raise ValueError( @@ -109,6 +112,7 @@ def __init__( pin_memory=pin_memory, augmentation=augmentation, metric=metric, + cache_path=cache_path, ) self.balance = balance diff --git a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py index 78c0d3c9c..84418e52b 100644 --- a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py +++ b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py @@ -88,6 +88,8 @@ class OverlappedSpeechDetection(SegmentationTask): metric : optional Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. Defaults to AUROC (area under the ROC curve). + cache_path : str, optional + path to directory where to write or load task caches """ OVERLAP_DEFAULTS = {"probability": 0.5, "snr_min": 0.0, "snr_max": 10.0} @@ -105,6 +107,7 @@ def __init__( pin_memory: bool = False, augmentation: BaseWaveformTransform = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + cache_path=None, ): super().__init__( protocol, @@ -115,6 +118,7 @@ def __init__( pin_memory=pin_memory, augmentation=augmentation, metric=metric, + cache_path=cache_path, ) self.specifications = Specifications( diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index d42e53a77..a9519eccc 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -110,6 +110,8 @@ class SpeakerDiarization(SegmentationTask): metric : optional Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. Defaults to AUROC (area under the ROC curve). + cache_path : str, optional + path to directory where to write or load task caches References ---------- @@ -140,6 +142,7 @@ def __init__( augmentation: BaseWaveformTransform = None, vad_loss: Literal["bce", "mse"] = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + cache_path=None, max_num_speakers: int = None, # deprecated in favor of `max_speakers_per_chunk`` loss: Literal["bce", "mse"] = None, # deprecated ): @@ -152,6 +155,7 @@ def __init__( pin_memory=pin_memory, augmentation=augmentation, metric=metric, + cache_path=cache_path, ) if not isinstance(protocol, SpeakerDiarizationProtocol): diff --git a/pyannote/audio/tasks/segmentation/voice_activity_detection.py b/pyannote/audio/tasks/segmentation/voice_activity_detection.py index 259aa3866..2cc616170 100644 --- a/pyannote/audio/tasks/segmentation/voice_activity_detection.py +++ b/pyannote/audio/tasks/segmentation/voice_activity_detection.py @@ -74,6 +74,8 @@ class VoiceActivityDetection(SegmentationTask): metric : optional Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. Defaults to AUROC (area under the ROC curve). + cache_path : str, optional + path to directory where to write or load task caches """ def __init__( @@ -88,6 +90,7 @@ def __init__( pin_memory: bool = False, augmentation: BaseWaveformTransform = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + cache_path=None, ): super().__init__( protocol, @@ -98,6 +101,7 @@ def __init__( pin_memory=pin_memory, augmentation=augmentation, metric=metric, + cache_path=cache_path, ) self.balance = balance From 4b8e8a222dd1f49633527d7131ae6c3f0eda27ae Mon Sep 17 00:00:00 2001 From: clement-pages Date: Thu, 9 Nov 2023 10:19:10 +0100 Subject: [PATCH 48/83] add training tests using task caches --- tests/test_train.py | 118 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) diff --git a/tests/test_train.py b/tests/test_train.py index 65fcc901e..4caf4c490 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -1,4 +1,5 @@ import pytest +from pathlib import Path from pyannote.database import FileFinder, get_protocol from pytorch_lightning import Trainer @@ -12,6 +13,8 @@ SupervisedRepresentationLearningWithArcFace, ) +CACHE_FILE_PATH = "./cache/cache_file" + @pytest.fixture() def protocol(): @@ -27,6 +30,20 @@ def test_train_segmentation(protocol): trainer.fit(model) +def test_train_segmentation_with_cached_data_mono_device(protocol): + first_task = SpeakerDiarization(protocol, cache_path=CACHE_FILE_PATH) + first_model = SimpleSegmentationModel(task=first_task) + first_trainer = Trainer(fast_dev_run=True, accelerator="cpu", devices=1) + first_trainer.fit(first_model) + + second_task = SpeakerDiarization(protocol, cache_path=CACHE_FILE_PATH) + second_model = SimpleSegmentationModel(task=second_task) + second_trainer = Trainer(fast_dev_run=True, accelerator="cpu", devices=1) + second_trainer.fit(second_model) + + Path(CACHE_FILE_PATH).unlink(missing_ok=True) + + def test_train_multilabel_segmentation(protocol): multilabel_segmentation = MultiLabelSegmentation(protocol) model = SimpleSegmentationModel(task=multilabel_segmentation) @@ -34,6 +51,20 @@ def test_train_multilabel_segmentation(protocol): trainer.fit(model) +def test_train_multilabel_segmentation_with_cached_data_mono_device(protocol): + first_task = MultiLabelSegmentation(protocol, cache_path=CACHE_FILE_PATH) + first_model = SimpleSegmentationModel(task=first_task) + first_trainer = Trainer(fast_dev_run=True, accelerator="cpu", devices=1) + first_trainer.fit(first_model) + + second_task = MultiLabelSegmentation(protocol, cache_path=CACHE_FILE_PATH) + second_model = SimpleSegmentationModel(task=second_task) + second_trainer = Trainer(fast_dev_run=True, accelerator="cpu", devices=1) + second_trainer.fit(second_model) + + Path(CACHE_FILE_PATH).unlink(missing_ok=True) + + def test_train_supervised_representation_with_arcface(protocol): supervised_representation_with_arface = SupervisedRepresentationLearningWithArcFace(protocol) model = SimpleEmbeddingModel(task=supervised_representation_with_arface) @@ -48,6 +79,20 @@ def test_train_voice_activity_detection(protocol): trainer.fit(model) +def test_train_voice_activity_detection_with_cached_data_mono_device(protocol): + first_task = VoiceActivityDetection(protocol, cache_path=CACHE_FILE_PATH) + first_model = SimpleSegmentationModel(task=first_task) + first_trainer = Trainer(fast_dev_run=True, accelerator="cpu", devices=1) + first_trainer.fit(first_model) + + second_task = VoiceActivityDetection(protocol, cache_path=CACHE_FILE_PATH) + second_model = SimpleSegmentationModel(task=second_task) + second_trainer = Trainer(fast_dev_run=True, accelerator="cpu", devices=1) + second_trainer.fit(second_model) + + Path(CACHE_FILE_PATH).unlink(missing_ok=True) + + def test_train_overlapped_speech_detection(protocol): overlapped_speech_detection = OverlappedSpeechDetection(protocol) model = SimpleSegmentationModel(task=overlapped_speech_detection) @@ -55,6 +100,20 @@ def test_train_overlapped_speech_detection(protocol): trainer.fit(model) +def test_train_overlapped_speech_detection_with_cached_data_mono_device(protocol): + first_task = OverlappedSpeechDetection(protocol, cache_path=CACHE_FILE_PATH) + first_model = SimpleSegmentationModel(task=first_task) + first_trainer = Trainer(fast_dev_run=True, accelerator="cpu", devices=1) + first_trainer.fit(first_model) + + second_task = OverlappedSpeechDetection(protocol, cache_path=CACHE_FILE_PATH) + second_model = SimpleSegmentationModel(task=second_task) + second_trainer = Trainer(fast_dev_run=True, accelerator="cpu", devices=1) + second_trainer.fit(second_model) + + Path(CACHE_FILE_PATH).unlink(missing_ok=True) + + def test_finetune_with_task_that_does_not_need_setup_for_specs(protocol): voice_activity_detection = VoiceActivityDetection(protocol) model = SimpleSegmentationModel(task=voice_activity_detection) @@ -79,6 +138,20 @@ def test_finetune_with_task_that_needs_setup_for_specs(protocol): trainer.fit(model) +def test_finetune_with_task_that_needs_setup_for_specs_and_with_cache(protocol): + segmentation = SpeakerDiarization(protocol, cache_path=CACHE_FILE_PATH) + model = SimpleSegmentationModel(task=segmentation) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") + trainer.fit(model) + + segmentation = SpeakerDiarization(protocol, cache_path=CACHE_FILE_PATH) + model.task = segmentation + trainer = Trainer(fast_dev_run=True, accelerator="cpu") + trainer.fit(model) + + Path(CACHE_FILE_PATH).unlink(missing_ok=True) + + def test_transfer_with_task_that_does_not_need_setup_for_specs(protocol): segmentation = SpeakerDiarization(protocol) model = SimpleSegmentationModel(task=segmentation) @@ -116,6 +189,21 @@ def test_finetune_freeze_with_task_that_needs_setup_for_specs(protocol): trainer.fit(model) +def test_finetune_freeze_with_task_that_needs_setup_for_specs_and_with_cache(protocol): + segmentation = SpeakerDiarization(protocol, cache_path=CACHE_FILE_PATH) + model = SimpleSegmentationModel(task=segmentation) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") + trainer.fit(model) + + segmentation = SpeakerDiarization(protocol) + model.task = segmentation + model.freeze_up_to("mfcc") + trainer = Trainer(fast_dev_run=True, accelerator="cpu") + trainer.fit(model) + + Path(CACHE_FILE_PATH).unlink(missing_ok=True) + + def test_finetune_freeze_with_task_that_does_not_need_setup_for_specs(protocol): vad = VoiceActivityDetection(protocol) model = SimpleSegmentationModel(task=vad) @@ -129,6 +217,36 @@ def test_finetune_freeze_with_task_that_does_not_need_setup_for_specs(protocol): trainer.fit(model) +def test_finetune_freeze_with_task_that_does_not_need_setup_for_specs_and_with_cache(protocol): + vad = VoiceActivityDetection(protocol, cache_path=CACHE_FILE_PATH) + model = SimpleSegmentationModel(task=vad) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") + trainer.fit(model) + + vad = VoiceActivityDetection(protocol, cache_path=CACHE_FILE_PATH) + model.task = vad + model.freeze_up_to("mfcc") + trainer = Trainer(fast_dev_run=True, accelerator="cpu") + trainer.fit(model) + + Path(CACHE_FILE_PATH).unlink(missing_ok=True) + + +def test_finetune_freeze_with_task_that_does_not_need_setup_for_specs_and_with_cache(protocol): + vad = VoiceActivityDetection(protocol, cache_path=CACHE_FILE_PATH) + model = SimpleSegmentationModel(task=vad) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") + trainer.fit(model) + + vad = VoiceActivityDetection(protocol, cache_path=CACHE_FILE_PATH) + model.task = vad + model.freeze_up_to("mfcc") + trainer = Trainer(fast_dev_run=True, accelerator="cpu") + trainer.fit(model) + + Path(CACHE_FILE_PATH).unlink(missing_ok=True) + + def test_transfer_freeze_with_task_that_does_not_need_setup_for_specs(protocol): segmentation = SpeakerDiarization(protocol) model = SimpleSegmentationModel(task=segmentation) From 45918bd3e99ac8e80624184248fb67344c910b55 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Thu, 9 Nov 2023 10:21:51 +0100 Subject: [PATCH 49/83] update `cache_path` type and docstrings --- pyannote/audio/core/task.py | 6 +++--- pyannote/audio/tasks/segmentation/multilabel.py | 4 ++-- .../audio/tasks/segmentation/overlapped_speech_detection.py | 6 +++--- pyannote/audio/tasks/segmentation/speaker_diarization.py | 6 +++--- .../audio/tasks/segmentation/voice_activity_detection.py | 6 +++--- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 519876511..f6d7b9991 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -225,7 +225,7 @@ def __init__( pin_memory: bool = False, augmentation: BaseWaveformTransform = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, - cache_path=None + cache_path: Optional[Union[str, None]] = None ): super().__init__() @@ -278,7 +278,7 @@ def prepare_data(self): Notes ----- - Called only once on the main process (and only on it). + Called only once on the main process (and only on it), for global_rank 0. """ if self.cache_path is not None: cache_path = Path(self.cache_path) @@ -611,7 +611,7 @@ def prepare_data(self): self.has_prepared_data = True def setup(self, stage=None): - """Setup""" + """Setup data on each device""" if not self.has_setup_metadata: # if no cache directory was provided by the user and task data was not already prepared if self.cache_path is None and not self.has_prepared_data: diff --git a/pyannote/audio/tasks/segmentation/multilabel.py b/pyannote/audio/tasks/segmentation/multilabel.py index 4fbce9cb2..06655fee8 100644 --- a/pyannote/audio/tasks/segmentation/multilabel.py +++ b/pyannote/audio/tasks/segmentation/multilabel.py @@ -80,7 +80,7 @@ class MultiLabelSegmentation(SegmentationTask): Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. Defaults to AUROC (area under the ROC curve). cache_path : str, optional - path to directory where to write or load task caches + path to file where to write or load task caches """ def __init__( @@ -96,7 +96,7 @@ def __init__( pin_memory: bool = False, augmentation: BaseWaveformTransform = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, - cache_path=None, + cache_path: Optional[Union[str, None]] = None, ): if not isinstance(protocol, SegmentationProtocol): raise ValueError( diff --git a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py index 84418e52b..e3d6ac259 100644 --- a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py +++ b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py @@ -21,7 +21,7 @@ # SOFTWARE. -from typing import Dict, Sequence, Text, Tuple, Union +from typing import Dict, Optional, Sequence, Text, Tuple, Union import numpy as np from pyannote.core import Segment, SlidingWindowFeature @@ -89,7 +89,7 @@ class OverlappedSpeechDetection(SegmentationTask): Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. Defaults to AUROC (area under the ROC curve). cache_path : str, optional - path to directory where to write or load task caches + path to file where to write or load task caches """ OVERLAP_DEFAULTS = {"probability": 0.5, "snr_min": 0.0, "snr_max": 10.0} @@ -107,7 +107,7 @@ def __init__( pin_memory: bool = False, augmentation: BaseWaveformTransform = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, - cache_path=None, + cache_path: Optional[Union[str, None]] = None, ): super().__init__( protocol, diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index a9519eccc..acf5c0279 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -23,7 +23,7 @@ import math import warnings from collections import Counter -from typing import Dict, Literal, Sequence, Text, Tuple, Union +from typing import Dict, Literal, Optional, Sequence, Text, Tuple, Union import numpy as np import torch @@ -111,7 +111,7 @@ class SpeakerDiarization(SegmentationTask): Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. Defaults to AUROC (area under the ROC curve). cache_path : str, optional - path to directory where to write or load task caches + path to file where to write or load task caches References ---------- @@ -142,7 +142,7 @@ def __init__( augmentation: BaseWaveformTransform = None, vad_loss: Literal["bce", "mse"] = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, - cache_path=None, + cache_path: Optional[Union[str, None]] = None, max_num_speakers: int = None, # deprecated in favor of `max_speakers_per_chunk`` loss: Literal["bce", "mse"] = None, # deprecated ): diff --git a/pyannote/audio/tasks/segmentation/voice_activity_detection.py b/pyannote/audio/tasks/segmentation/voice_activity_detection.py index 2cc616170..989f5caf2 100644 --- a/pyannote/audio/tasks/segmentation/voice_activity_detection.py +++ b/pyannote/audio/tasks/segmentation/voice_activity_detection.py @@ -20,7 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Dict, Sequence, Text, Tuple, Union +from typing import Dict, Optional, Sequence, Text, Tuple, Union import numpy as np from pyannote.core import Segment, SlidingWindowFeature @@ -75,7 +75,7 @@ class VoiceActivityDetection(SegmentationTask): Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. Defaults to AUROC (area under the ROC curve). cache_path : str, optional - path to directory where to write or load task caches + path to file where to write or load task caches """ def __init__( @@ -90,7 +90,7 @@ def __init__( pin_memory: bool = False, augmentation: BaseWaveformTransform = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, - cache_path=None, + cache_path: Optional[Union[str, None]] = None, ): super().__init__( protocol, From 980414ef753c946cc1939ec82e84abc67090fbed Mon Sep 17 00:00:00 2001 From: clement-pages Date: Thu, 9 Nov 2023 11:32:35 +0100 Subject: [PATCH 50/83] fix `classes` variable used before assigment This issue occured when a list of classes was specified during `MultiLabelSegmentation` instanciation. --- pyannote/audio/core/task.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index f6d7b9991..4894ec3fc 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -304,11 +304,16 @@ def prepare_data(self): elif isinstance(self.protocol, SegmentationProtocol): classes = getattr(self, "classes", list()) + # save all protocol data in a dict + prepared_data = {} + # make sure classes attribute exists (and set to None if it did not exist) - self.classes = getattr(self, "classes", None) - if self.classes is None: + prepared_data["classes"] = getattr(self, "classes", None) + if prepared_data["classes"] is None: classes = list() # metadata_unique_values["classes"] = list(classes) + else: + classes = prepared_data["classes"] audios = list() # list of path to audio files audio_infos = list() @@ -356,7 +361,7 @@ def prepare_data(self): # if task was not initialized with a fixed list of classes, # we build it as the union of all classes found in files - if self.classes is None: + if prepared_data["classes"] is None: for klass in local_classes: if klass not in classes: classes.append(klass) @@ -368,15 +373,15 @@ def prepare_data(self): # we make sure that all files use a subset of these classes # if they don't, we issue a warning and ignore the extra classes else: - extra_classes = set(local_classes) - set(self.classes) + extra_classes = set(local_classes) - set(prepared_data["classes"]) if extra_classes: warnings.warn( f"Ignoring extra classes ({', '.join(extra_classes)}) found for file {file['uri']} ({file['database']}). " ) annotated_classes.append( [ - self.classes.index(klass) - for klass in set(local_classes) & set(self.classes) + prepared_data["classes"].index(klass) + for klass in set(local_classes) & set(prepared_data["classes"]) ] ) @@ -522,8 +527,6 @@ def prepare_data(self): ] dtype = [(key, "i") for key in metadata_unique_values] - # save all protocol data in a dict - prepared_data = {} prepared_data["metadata"] = np.array(metadata, dtype=dtype) prepared_data["audios"] = np.array(audios, dtype=np.string_) @@ -548,7 +551,7 @@ def prepare_data(self): # into a single (num_files x num_classes) numpy array: # * True indicates that this particular class was annotated for this particular file (though it may not be active in this file) # * False indicates that this particular class was not even annotated (i.e. its absence does not imply that it is not active in this file) - if isinstance(self.protocol, SegmentationProtocol) and self.classes is None: + if isinstance(self.protocol, SegmentationProtocol) and prepared_data["classes"] is None: prepared_data["classes"] = classes annotated_classes_array = np.zeros( (len(annotated_classes), len(classes)), dtype=np.bool_ From c1fbb816028199489a8eb8495d77b0009dc719ae Mon Sep 17 00:00:00 2001 From: clement-pages Date: Wed, 15 Nov 2023 16:02:37 +0100 Subject: [PATCH 51/83] fix: fix residual merge problems --- pyannote/audio/core/model.py | 2 +- pyannote/audio/tasks/segmentation/mixins.py | 5 ----- pyannote/audio/tasks/segmentation/speaker_diarization.py | 2 +- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index 3e9b6a74d..7bec676b4 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -222,7 +222,7 @@ def prepare_data(self): def setup(self, stage=None): if stage == "fit": - self.task.setup_metadata() + self.task.setup() # list of layers before adding task-dependent layers before = set((name, id(module)) for name, module in self.named_modules()) diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index 00ea47bfc..e1c96c18e 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -23,11 +23,6 @@ import itertools import math import random -<<<<<<< HEAD -======= -import warnings -from collections import defaultdict ->>>>>>> feat/joint-diarization-and-embedding from typing import Dict, Sequence, Union import matplotlib.pyplot as plt diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index 88dff8bef..acf5c0279 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -23,7 +23,7 @@ import math import warnings from collections import Counter -from typing import Dict, Literal, Sequence, Text, Tuple, Union +from typing import Dict, Literal, Optional, Sequence, Text, Tuple, Union import numpy as np import torch From 987e702987fa86a04bce1ef5f760f49200f8eaa0 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 21 Nov 2023 16:14:37 +0100 Subject: [PATCH 52/83] improve code readability --- pyannote/audio/core/task.py | 42 ++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 4894ec3fc..42d0017b2 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -286,9 +286,7 @@ def prepare_data(self): # data was already created, do nothing return # create a new cache directory at the path specified by the user - else: - cache_rep = Path(self.cache_path[: self.cache_path.rfind('/')]) - cache_rep.mkdir(parents=True, exist_ok=True) + cache_path.parent.mkdir(parents=True, exist_ok=True) # duration of training chunks # TODO: handle variable duration case duration = getattr(self, "duration", 0.0) @@ -615,24 +613,26 @@ def prepare_data(self): def setup(self, stage=None): """Setup data on each device""" - if not self.has_setup_metadata: - # if no cache directory was provided by the user and task data was not already prepared - if self.cache_path is None and not self.has_prepared_data: - warnings.warn("""No path to the directory containing the cache of prepared data - has been specified. Data preparation will therefore be carried out - on each process used for training. To speed up data preparation, you - can specify a cache directory when instantiating the task.""", stacklevel=1) - self.prepare_data() - return - # load data cached by prepare_data method into the task: - try: - with open(self.cache_path, 'rb') as cache_file: - self.prepared_data = pickle.load(cache_file) - except FileNotFoundError: - print("""Cached data for protocol not found. Ensure that prepare_data was - executed correctly and that the path to the task cache is correct""") - raise - self.has_setup_metadata = True + # if all data was assigned to the task, nothing to do + if self.has_setup_metadata: + return + # if no cache directory was provided by the user and task data was not already prepared + if self.cache_path is None and not self.has_prepared_data: + warnings.warn("No path to the directory containing the cache of prepared data" + " has been specified. Data preparation will therefore be carried out" + " on each process used for training. To speed up data preparation, you" + " can specify a cache directory when instantiating the task.", stacklevel=1) + self.prepare_data() + return + # load data cached by prepare_data method into the task: + try: + with open(self.cache_path, 'rb') as cache_file: + self.prepared_data = pickle.load(cache_file) + except FileNotFoundError: + print("""Cached data for protocol not found. Ensure that prepare_data was + executed correctly and that the path to the task cache is correct""") + raise + self.has_setup_metadata = True @property def specifications(self) -> Union[Specifications, Tuple[Specifications]]: From 042dc437e6accddae42465239c5a161ec9b06f3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Pag=C3=A9s?= <55240756+clement-pages@users.noreply.github.com> Date: Mon, 27 Nov 2023 15:22:50 +0100 Subject: [PATCH 53/83] improve: use `numpy` method for w/r task cache instead `pickle` (#1) * use npz archive instead pickle to save task data * improve code readability * improve(task): update numpy array dtypes In order to use types whose size better machtes the contents of the arrays * remove `end` entry from `annotated_regions` numpy array This entry was redundant with the start and duration entries, since `end` = `start` + `duration`. * fix: allow data preparation to be finished when task has no validation * improve: clear data lists after assignation to `self.prepared_data` This is to avoid data redundancy in the `prepare_data` method --------- Co-authored-by: clement-pages --- pyannote/audio/core/task.py | 91 +++++++++++-------- pyannote/audio/tasks/segmentation/mixins.py | 4 +- .../tasks/segmentation/speaker_diarization.py | 2 +- 3 files changed, 54 insertions(+), 43 deletions(-) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 42d0017b2..7520c3a96 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -27,7 +27,6 @@ import multiprocessing import numpy as np from pathlib import Path -import pickle import sys import warnings from dataclasses import dataclass @@ -283,7 +282,7 @@ def prepare_data(self): if self.cache_path is not None: cache_path = Path(self.cache_path) if cache_path.exists(): - # data was already created, do nothing + # data was already created, nothing to do return # create a new cache directory at the path specified by the user cache_path.parent.mkdir(parents=True, exist_ok=True) @@ -450,7 +449,6 @@ def prepare_data(self): file_id, segment.duration, segment.start, - segment.end, ) annotated_regions.append(annotated_region) @@ -523,27 +521,32 @@ def prepare_data(self): tuple(metadatum.get(key, -1) for key in metadata_unique_values) for metadatum in metadata ] - dtype = [(key, "i") for key in metadata_unique_values] + dtype = [(key, "b") for key in metadata_unique_values] prepared_data["metadata"] = np.array(metadata, dtype=dtype) + metadata.clear() prepared_data["audios"] = np.array(audios, dtype=np.string_) + audios.clear() # turn list of files metadata into a single numpy array # TODO: improve using https://github.com/pytorch/pytorch/issues/13246#issuecomment-617140519 dtype = [ ("sample_rate", "i"), ("num_frames", "i"), - ("num_channels", "i"), - ("bits_per_sample", "i"), + ("num_channels", "B"), + ("bits_per_sample", "B"), ] prepared_data["audio_infos"] = np.array(audio_infos, dtype=dtype) + audio_infos.clear() prepared_data["audio_encodings"] = np.array(audio_encodings, dtype=np.string_) + audio_encodings.clear() prepared_data["annotated_duration"] = np.array(annotated_duration) + annotated_duration.clear() # turn list of annotated regions into a single numpy array - dtype = [("file_id", "i"), ("duration", "f"), ("start", "f"), ("end", "f")] - annotated_regions_array = np.array(annotated_regions, dtype=dtype) - prepared_data["annotated_regions"] = annotated_regions_array + dtype = [("file_id", "i"), ("duration", "f"), ("start", "f")] + prepared_data["annotated_regions"] = np.array(annotated_regions, dtype=dtype) + annotated_regions.clear() # convert annotated_classes (which is a list of list of classes, one list of classes per file) # into a single (num_files x num_classes) numpy array: @@ -557,6 +560,8 @@ def prepare_data(self): for file_id, classes in enumerate(annotated_classes): annotated_classes_array[file_id, classes] = True prepared_data["annotated_classes"] = annotated_classes_array + annotated_classes.clear() + del annotated_classes_array # turn list of annotations into a single numpy array dtype = [ @@ -569,45 +574,46 @@ def prepare_data(self): ] prepared_data["annotations"] = np.array(annotations, dtype=dtype) + annotations.clear() prepared_data["metadata_unique_values"] = metadata_unique_values + metadata_unique_values.clear() - if not self.has_validation: - return + if self.has_validation: + validation_chunks = list() + + # obtain indexes of files in the validation subset + validation_file_ids = np.where( + prepared_data["metadata"]["subset"] == Subsets.index("development") + )[0] + + # iterate over files in the validation subset + for file_id in validation_file_ids: + # get annotated regions in file + annotated_regions = prepared_data["annotated_regions"][ + prepared_data["annotated_regions"]["file_id"] == file_id + ] + + # iterate over annotated regions + for annotated_region in annotated_regions: + # number of chunks in annotated region + num_chunks = round(annotated_region["duration"] // duration) + + # iterate over chunks + for c in range(num_chunks): + start_time = annotated_region["start"] + c * duration + validation_chunks.append((file_id, start_time, duration)) + + dtype = [("file_id", "i"), ("start", "f"), ("duration", "f")] + prepared_data["validation_chunks"] = np.array(validation_chunks, dtype=dtype) + validation_chunks.clear() - validation_chunks = list() - - # obtain indexes of files in the validation subset - validation_file_ids = np.where( - prepared_data["metadata"]["subset"] == Subsets.index("development") - )[0] - - # iterate over files in the validation subset - for file_id in validation_file_ids: - # get annotated regions in file - annotated_regions = annotated_regions_array[ - annotated_regions_array["file_id"] == file_id - ] - - # iterate over annotated regions - for annotated_region in annotated_regions: - # number of chunks in annotated region - num_chunks = round(annotated_region["duration"] // duration) - - # iterate over chunks - for c in range(num_chunks): - start_time = annotated_region["start"] + c * duration - validation_chunks.append((file_id, start_time, duration)) - - dtype = [("file_id", "i"), ("start", "f"), ("duration", "f")] - prepared_data["validation_chunks"] = np.array(validation_chunks, dtype=dtype) - self.prepared_data = prepared_data self.has_setup_metadata = True # save preparated data on the disk if self.cache_path is not None: with open(self.cache_path, 'wb') as cache_file: - pickle.dump(prepared_data, cache_file) + np.savez_compressed(cache_file, **prepared_data) self.has_prepared_data = True @@ -627,13 +633,14 @@ def setup(self, stage=None): # load data cached by prepare_data method into the task: try: with open(self.cache_path, 'rb') as cache_file: - self.prepared_data = pickle.load(cache_file) + self.prepared_data = dict(np.load(cache_file, allow_pickle=True)) except FileNotFoundError: print("""Cached data for protocol not found. Ensure that prepare_data was executed correctly and that the path to the task cache is correct""") raise self.has_setup_metadata = True + @property def specifications(self) -> Union[Specifications, Tuple[Specifications]]: # setup metadata on-demand the first time specifications are requested and missing @@ -649,6 +656,8 @@ def specifications( @property def has_prepared_data(self): + # This flag indicates if data for this task was generated, and + # optionally saved on the disk return getattr(self, "_has_prepared_data", False) @has_prepared_data.setter @@ -657,6 +666,8 @@ def has_prepared_data(self, value: bool): @property def has_setup_metadata(self): + # This flag indicates if data was assigned to this task, directly from prepared + # data or by reading in a cached file on the disk return getattr(self, "_has_setup_metadata", False) @has_setup_metadata.setter diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index e1c96c18e..4b1dc8523 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -138,8 +138,8 @@ def train__iter__helper(self, rng: random.Random, **filters): ) # select one chunk at random in this annotated region - _, _, start, end = self.prepared_data["annotated_regions"][annotated_region_index] - start_time = rng.uniform(start, end - duration) + _, region_duration, start = self.prepared_data["annotated_regions"][annotated_region_index] + start_time = rng.uniform(start, start + region_duration - duration) yield self.prepare_chunk(file_id, start_time, duration) diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index acf5c0279..99a6e2904 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -211,7 +211,7 @@ def setup(self): for region in annotated_regions: # find annotations within current region region_start = region["start"] - region_end = region["end"] + region_end = region["start"] + region["duration"] region_annotations = annotations[ np.where( (annotations["start"] >= region_start) From 001187047a09eda2f8c890fdacf72483d487574e Mon Sep 17 00:00:00 2001 From: clement-pages Date: Wed, 29 Nov 2023 08:40:30 +0100 Subject: [PATCH 54/83] improve: remove complete redefinition of `setup` in joint task Now the joint task uses `prepare_data` and `setup` from core `Task` and `SpeakerDiarization` task. --- .../speaker_diarization_and_embedding.py | 292 +----------------- 1 file changed, 11 insertions(+), 281 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 78826f114..cc0fd05dc 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -25,7 +25,7 @@ import math import itertools import random -from typing import Literal, Union, Sequence, Dict +from typing import Dict, Literal, Optional, Sequence, Union import warnings from matplotlib import pyplot as plt import numpy as np @@ -97,7 +97,8 @@ def __init__( margin : float = 28.6, scale: float = 64.0, alpha: float = 0.5, - augmentation: BaseWaveformTransform = None + augmentation: BaseWaveformTransform = None, + cache_path: Optional[Union[str, None]] = None, ) -> None: super().__init__( protocol, @@ -106,6 +107,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, augmentation=augmentation, + cache_path=cache_path, ) self.weigh_by_cardinality = weigh_by_cardinality @@ -116,11 +118,10 @@ def __init__( self.scale = scale self.alpha = alpha - # keep track of the use of database available in the meta protocol # * embedding databases are those with global speaker label scope # * diarization databases are those with file or database speaker label scope - self.embedding_database_files = [] + self.global_files_id = [] def get_file(self, file_id): @@ -155,277 +156,10 @@ def setup(self, stage="fit"): Setup stage. Defaults to 'fit'. """ - # duration of training chunks - # TODO: handle variable duration case - duration = getattr(self, "duration", 0.0) - - # list of possible values for each metadata key - metadata_unique_values = defaultdict(list) - - metadata_unique_values["subset"] = Subsets - - if isinstance(self.protocol, SpeakerDiarizationProtocol): - metadata_unique_values["scope"] = Scopes - - # make sure classes attribute exists (and set to None if it did not exist) - self.classes = getattr(self, "classes", None) - if self.classes is None: - classes = list() - # metadata_unique_values["classes"] = list(classes) - - audios = list() # list of path to audio files - audio_infos = list() - audio_encodings = list() - metadata = list() # list of metadata - - annotated_duration = list() # total duration of annotated regions (per file) - annotated_regions = list() # annotated regions - annotations = list() # actual annotations - annotated_classes = list() # list of annotated classes (per file) - unique_labels = list() - - if self.has_validation: - files_iter = itertools.chain( - self.protocol.train(), self.protocol.development() - ) - else: - files_iter = self.protocol.train() - - for file_id, file in enumerate(files_iter): - - # gather metadata and update metadata_unique_values so that each metadatum - # (e.g. source database or label) is represented by an integer. - metadatum = dict() - - # keep track of source database and subset (train, development, or test) - if file["database"] not in metadata_unique_values["database"]: - metadata_unique_values["database"].append(file["database"]) - metadatum["database"] = metadata_unique_values["database"].index( - file["database"] - ) - metadatum["subset"] = Subsets.index(file["subset"]) - - # keep track of speaker label scope (file, database, or global) for speaker diarization protocols - if isinstance(self.protocol, SpeakerDiarizationProtocol): - metadatum["scope"] = Scopes.index(file["scope"]) - # keep track of files where speaker label scope is global for embedding subtask - if file["scope"] == 'global': - self.embedding_database_files.append(file_id) - - remaining_metadata_keys = set(file) - set( - [ - "uri", - "database", - "subset", - "audio", - "torchaudio.info", - "scope", - "classes", - "annotation", - "annotated", - ] - ) - - # keep track of any other (integer or string) metadata provided by the protocol - # (e.g. a "domain" key for domain-adversarial training) - for key in remaining_metadata_keys: - - value = file[key] - - if isinstance(value, str): - if value not in metadata_unique_values[key]: - metadata_unique_values[key].append(value) - metadatum[key] = metadata_unique_values[key].index(value) - - elif isinstance(value, int): - metadatum[key] = value - - else: - warnings.warn( - f"Ignoring '{key}' metadata because of its type ({type(value)}). Only str and int are supported for now.", - category=UserWarning, - ) - - metadata.append(metadatum) - - database_unique_labels = list() - - # reset list of file-scoped labels - file_unique_labels = list() - - # path to audio file - audios.append(str(file["audio"])) - - # audio info - audio_info = file["torchaudio.info"] - audio_infos.append( - ( - audio_info.sample_rate, # sample rate - audio_info.num_frames, # number of frames - audio_info.num_channels, # number of channels - audio_info.bits_per_sample, # bits per sample - ) - ) - audio_encodings.append(audio_info.encoding) # encoding - - # annotated regions and duration - _annotated_duration = 0.0 - for segment in file["annotated"]: - - # skip annotated regions that are shorter than training chunk duration - if segment.duration < duration: - continue - - # append annotated region - annotated_region = ( - file_id, - segment.duration, - segment.start, - segment.end, - ) - annotated_regions.append(annotated_region) - - # increment annotated duration - _annotated_duration += segment.duration - - # append annotated duration - annotated_duration.append(_annotated_duration) - - # annotations - for segment, _, label in file["annotation"].itertracks(yield_label=True): - - # "scope" is provided by speaker diarization protocols to indicate - # whether speaker labels are local to the file ('file'), consistent across - # all files in a database ('database'), or globally consistent ('global') - - if "scope" in file: - - # 0 = 'file' - # 1 = 'database' - # 2 = 'global' - scope = Scopes.index(file["scope"]) - - # update list of file-scope labels - if label not in file_unique_labels: - file_unique_labels.append(label) - # and convert label to its (file-scope) index - file_label_idx = file_unique_labels.index(label) - - database_label_idx = global_label_idx = -1 - - if scope > 0: # 'database' or 'global' - - # update list of database-scope labels - if label not in database_unique_labels: - database_unique_labels.append(label) - - # and convert label to its (database-scope) index - database_label_idx = database_unique_labels.index(label) - - if scope > 1: # 'global' - - # update list of global-scope labels - if label not in unique_labels: - unique_labels.append(label) - # and convert label to its (global-scope) index - global_label_idx = unique_labels.index(label) - - # basic segmentation protocols do not provide "scope" information - # as classes are global by definition - - else: - try: - file_label_idx = ( - database_label_idx - ) = global_label_idx = classes.index(label) - except ValueError: - # skip labels that are not in the list of classes - continue - - annotations.append( - ( - file_id, # index of file - segment.start, # start time - segment.end, # end time - file_label_idx, # file-scope label index - database_label_idx, # database-scope label index - global_label_idx, # global-scope index - ) - ) - - # since not all metadata keys are present in all files, fallback to -1 when a key is missing - metadata = [ - tuple(metadatum.get(key, -1) for key in metadata_unique_values) - for metadatum in metadata - ] - dtype = [(key, "i") for key in metadata_unique_values] - self.metadata = np.array(metadata, dtype=dtype) - - # NOTE: read with str(self.audios[file_id], encoding='utf-8') - self.audios = np.array(audios, dtype=np.string_) - - # turn list of files metadata into a single numpy array - # TODO: improve using https://github.com/pytorch/pytorch/issues/13246#issuecomment-617140519 - - dtype = [ - ("sample_rate", "i"), - ("num_frames", "i"), - ("num_channels", "i"), - ("bits_per_sample", "i"), - ] - self.audio_infos = np.array(audio_infos, dtype=dtype) - self.audio_encodings = np.array(audio_encodings, dtype=np.string_) - - self.annotated_duration = np.array(annotated_duration) - - # turn list of annotated regions into a single numpy array - dtype = [("file_id", "i"), ("duration", "f"), ("start", "f"), ("end", "f")] - self.annotated_regions = np.array(annotated_regions, dtype=dtype) - - # turn list of annotations into a single numpy array - dtype = [ - ("file_id", "i"), - ("start", "f"), - ("end", "f"), - ("file_label_idx", "i"), - ("database_label_idx", "i"), - ("global_label_idx", "i"), - ] - self.annotations = np.array(annotations, dtype=dtype) - - self.metadata_unique_values = metadata_unique_values - - if not self.has_validation: - return - - validation_chunks = list() - - # obtain indexes of files in the validation subset - validation_file_ids = np.where( - self.metadata["subset"] == Subsets.index("development") - )[0] - - # iterate over files in the validation subset - for file_id in validation_file_ids: - - # get annotated regions in file - annotated_regions = self.annotated_regions[ - self.annotated_regions["file_id"] == file_id - ] - - # iterate over annotated regions - for annotated_region in annotated_regions: - - # number of chunks in annotated region - num_chunks = round(annotated_region["duration"] // duration) - - # iterate over chunks - for c in range(num_chunks): - start_time = annotated_region["start"] + c * duration - validation_chunks.append((file_id, start_time, duration)) - - dtype = [("file_id", "i"), ("start", "f"), ("duration", "f")] - self.validation_chunks = np.array(validation_chunks, dtype=dtype) + super().setup() + global_scope_mask = self.prepared_data["annotations"]["global_label_idx"] > -1 + self.global_files_id = np.unique(self.prepared_data["annotations"]["file_id"][global_scope_mask]) + global_classes = np.unique(self.prepared_data["annotations"]["global_label_idx"][global_scope_mask]) speaker_diarization = Specifications( duration=self.duration, @@ -439,9 +173,8 @@ def setup(self, stage="fit"): duration=self.duration, resolution=Resolution.CHUNK, problem=Problem.REPRESENTATION, - classes=unique_labels, + classes=global_classes, ) - self.specifications = (speaker_diarization, speaker_embedding) def prepare_chunk(self, file_id: int, start_time: float, duration: float): @@ -653,21 +386,18 @@ def train__iter__helper(self, rng : random.Random, **filters): # choose between diarization or embedding subtask according to a ratio # between these two tasks if np.random.uniform() < self.database_ratio: - subtask = Subtasks.index("diarization") file_id, start_time = self.draw_diarization_chunk(file_ids, prob_annotated_duration, rng, duration) else: - subtask = Subtasks.index("embedding") # shuffle embedding classes list and go through this shuffled list # to make sure to see all the speakers during training if embedding_class_idx == len(shuffled_embedding_classes): rng.shuffle(shuffled_embedding_classes) embedding_class_idx = 0 klass = shuffled_embedding_classes[embedding_class_idx] - class_id = embedding_classes.index(klass) embedding_class_idx += 1 - file_id, start_time = self.draw_embedding_chunk(class_id, duration) + file_id, start_time = self.draw_embedding_chunk(klass, duration) sample = self.prepare_chunk(file_id, start_time, duration) yield sample From 6e6b62db7f1d630bddd453d1999dd92f08f1679b Mon Sep 17 00:00:00 2001 From: clement-pages Date: Wed, 29 Nov 2023 15:40:59 +0100 Subject: [PATCH 55/83] improve: remove duplicated attributes in `JointSpeakerDiarizationAndEmbedding` --- .../tasks/joint_task/speaker_diarization_and_embedding.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index cc0fd05dc..65679ffd1 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -103,6 +103,9 @@ def __init__( super().__init__( protocol, duration=duration, + max_speakers_per_chunk=max_speakers_per_chunk, + max_speakers_per_frame=max_speakers_per_frame, + weigh_by_cardinality=weigh_by_cardinality, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, @@ -110,14 +113,10 @@ def __init__( cache_path=cache_path, ) - self.weigh_by_cardinality = weigh_by_cardinality - self.max_speakers_per_chunk = max_speakers_per_chunk - self.max_speakers_per_frame = max_speakers_per_frame self.database_ratio = database_ratio self.margin = margin self.scale = scale self.alpha = alpha - # keep track of the use of database available in the meta protocol # * embedding databases are those with global speaker label scope # * diarization databases are those with file or database speaker label scope From e60873cd9ab75dd03b6fb32c6b788d6cb0ee528c Mon Sep 17 00:00:00 2001 From: clement-pages Date: Wed, 29 Nov 2023 15:47:16 +0100 Subject: [PATCH 56/83] update: replace old `Task` attributes with prepared_data in joint task --- .../speaker_diarization_and_embedding.py | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 65679ffd1..ce537c1c1 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -126,10 +126,10 @@ def get_file(self, file_id): file = dict() - file["audio"] = str(self.audios[file_id], encoding="utf-8") + file["audio"] = str(self.prepared_data["audios"][file_id], encoding="utf-8") - _audio_info = self.audio_infos[file_id] - _encoding = self.audio_encodings[file_id] + _audio_info = self.prepared_data["audio_infos"][file_id] + _encoding = self.prepared_data["audio_encodings"][file_id] sample_rate = _audio_info["sample_rate"] num_frames = _audio_info["num_frames"] @@ -207,7 +207,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): file = self.get_file(file_id) # get label scope - label_scope = Scopes[self.metadata[file_id]["scope"]] + label_scope = Scopes[self.prepared_data["metadata"][file_id]["scope"]] label_scope_key = f"{label_scope}_label_idx" chunk = Segment(start_time, start_time + duration) @@ -216,7 +216,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration, mode="pad") # gather all annotations of current file - annotations = self.annotations[self.annotations["file_id"] == file_id] + annotations = self.prepared_data["annotations"][self.prepared_data["annotations"]["file_id"] == file_id] # gather all annotations with non-empty intersection with current chunk chunk_annotations = annotations[ @@ -252,7 +252,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample["y"] = SlidingWindowFeature( y, self.model.example_output[0].frames, labels=labels ) - metadata = self.metadata[file_id] + metadata = self.prepared_data["metadata"][file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} sample["meta"]["file"] = file_id @@ -281,13 +281,13 @@ def draw_diarization_chunk(self, file_ids : np.ndarray, file_id = np.random.choice(file_ids, p=prob_annotated_duration) # find indices of annotated regions in this file annotated_region_indices = np.where( - self.annotated_regions["file_id"] == file_id + self.prepared_data["annotated_regions"]["file_id"] == file_id )[0] # turn annotated regions duration into a probability distribution - prob_annotaded_regions_duration = self.annotated_regions["duration"][ + prob_annotaded_regions_duration = self.prepared_data["annotated_regions"]["duration"][ annotated_region_indices - ] / np.sum(self.annotated_regions["duration"][annotated_region_indices]) + ] / np.sum(self.prepared_data["annotated_regions"]["duration"][annotated_region_indices]) # seletect one annotated region at random (with probability proportional to its duration) annotated_region_index = np.random.choice(annotated_region_indices, @@ -295,8 +295,8 @@ def draw_diarization_chunk(self, file_ids : np.ndarray, ) # select one chunk at random in this annotated region - _, _, start, end = self.annotated_regions[annotated_region_index] - start_time = rng.uniform(start, end - duration) + _, region_duration, start = self.prepared_data["annotated_regions"][annotated_region_index] + start_time = rng.uniform(start, start + region_duration - duration) return (file_id, start_time) @@ -321,8 +321,8 @@ class ID in the task speficiations """ # get index of the current class in the order of original class list # get segments for current class - class_segments_idx = self.annotations["global_label_idx"] == class_id - class_segments = self.annotations[class_segments_idx] + class_segments_idx = self.prepared_data["annotations"]["global_label_idx"] == class_id + class_segments = self.prepared_data["annotations"][class_segments_idx] # sample one segment from all the class segments: segments_duration = class_segments["end"] - class_segments["start"] @@ -353,14 +353,14 @@ def train__iter__helper(self, rng : random.Random, **filters): """ # indices of training files that matches domain filters - training = self.metadata["subset"] == Subsets.index("train") + training = self.prepared_data["metadata"]["subset"] == Subsets.index("train") for key, value in filters.items(): - training &= self.metadata[key] == value + training &= self.prepared_data["metadata"][key] == value file_ids = np.where(training)[0] # get the subset of embedding database files from training files - embedding_files_ids = file_ids[np.in1d(file_ids, self.embedding_database_files)] + embedding_files_ids = file_ids[np.in1d(file_ids, self.global_files_id)] - annotated_duration = self.annotated_duration[file_ids] + annotated_duration = self.prepared_data["annotated_duration"][file_ids] # set duration of files for the embedding part to zero, in order to not # drawn them for diarization part annotated_duration[embedding_files_ids] = 0. @@ -425,7 +425,7 @@ def train__iter__(self): else: # create subchunks = dict() - for product in itertools.product([self.metadata_unique_values[key] for key in balance]): + for product in itertools.product([self.prepared_data["metadata_unique_values"][key] for key in balance]): filters = {key : value for key, value in zip(balance, product)} subchunks[product] = self.train__iter__helper(rng, **filters) @@ -464,7 +464,7 @@ def collate_y(self, batch) -> torch.Tensor: labels = b["y"].labels num_speakers = len(labels) # embedding reference - y_emb = np.full((self.max_speakers_per_chunk,), -1, dtype=np.int) + y_emb = np.full((self.max_speakers_per_chunk,), -1, dtype=int) if num_speakers > self.max_speakers_per_chunk: # sort speakers in descending talkativeness order From 40cc903c1760ea96c74168c70164247959b12d61 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Wed, 29 Nov 2023 16:02:04 +0100 Subject: [PATCH 57/83] improve: handle multi-speaker embeddings in `example_output` --- pyannote/audio/core/model.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index 7bec676b4..216631d4f 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -202,8 +202,14 @@ def __example_output( _, num_frames, dimension = example_output.shape frame_duration = specifications.duration / num_frames frames = SlidingWindow(step=frame_duration, duration=frame_duration) + # else chunk resolution (Resolution.CHUNK) else: - _, dimension = example_output.shape + # if the model outputs only one embedding (mono-speaker case): + if len(example_output.shape) == 2: + _, dimension = example_output.shape + # if the model returns multiple embeddings (multi-speaker case): + else: + _, _, dimension = example_output.shape num_frames = None frames = None From 30ae9fbe3d64c62b7b43c49f45e98e2ec87a4e7a Mon Sep 17 00:00:00 2001 From: clement-pages Date: Thu, 30 Nov 2023 11:00:31 +0100 Subject: [PATCH 58/83] feat: add new end-to-end model for joint speaker diarization and embeddins This new model is based on a `WeSpeakerResnet34` for the speaker embeddings extraction part, and on `PyanNet` for (local) segmentation. --- .../models/joint/end_to_end_diarization.py | 57 ++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/pyannote/audio/models/joint/end_to_end_diarization.py b/pyannote/audio/models/joint/end_to_end_diarization.py index ec8dbcf88..591e32836 100644 --- a/pyannote/audio/models/joint/end_to_end_diarization.py +++ b/pyannote/audio/models/joint/end_to_end_diarization.py @@ -20,7 +20,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Literal, Optional +import os +from typing import List, Literal, Optional, Union from warnings import warn from einops import rearrange @@ -42,6 +43,60 @@ Subtasks = list(Subtask.__args__) +class WeSpeakerBasesEndToEndDiarization(Model): + """ + WeSpeaker-based joint speaker diarization and speaker + embedding extraction model + """ + def __init__( + self, + sincnet: dict = None, + lstm: dict = None, + linear: dict = None, + sample_rate=16000, + embedding_dim=256, + num_channels=1, + task: Optional[Union[Task, None]] = None, + ): + super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) + + # speakers embedding extraction submodel: + self.resnet34 = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM") + self.embedding_dim = embedding_dim + self.save_hyperparameters("embedding_dim") + + # speaker segmentation submodel: + self.pyannet = Model.from_pretrained( + "pyannote/segmentation-3.0", + use_auth_token=os.environ["HUGGINGFACE_TOKEN"], + strict=False + ) + + def build(self): + """""" + dia_specs = self.specifications[Subtasks.index("diarization")] + self.pyannet.specifications = dia_specs + self.pyannet.build() + self.powerset = Powerset( + len(dia_specs.classes), + dia_specs.powerset_max_classes, + ) + + def forward(self, waveformms: torch.Tensor) -> torch.Tensor: + """ + + Parameters + ---------- + waveforms : torch.Tensor + Batch of waveforms with shape (batch, channel, sample) + """ + dia_outputs = self.pyannet(waveformms) + weights = self.powerset.to_multilabel(dia_outputs) + weights = rearrange(weights, "b f s -> b s f") + emb_outputs = self.resnet34(waveformms, weights) + return (dia_outputs, emb_outputs) + + class SpeakerEndToEndDiarization(Model): """Speaker End-to-End Diarization and Embedding model SINCNET -- TDNN .. TDNN -- TDNN ..TDNN -- StatsPool -- Linear -- Classifier From 72f99165d87f50bb62fd959ab28e1c44e56dd51e Mon Sep 17 00:00:00 2001 From: clement-pages Date: Thu, 30 Nov 2023 14:31:07 +0100 Subject: [PATCH 59/83] fix: fix empty dict issue for `metadata_unique_values` in `prepared_data` --- pyannote/audio/core/task.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 7520c3a96..8822689f1 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -576,7 +576,6 @@ def prepare_data(self): prepared_data["annotations"] = np.array(annotations, dtype=dtype) annotations.clear() prepared_data["metadata_unique_values"] = metadata_unique_values - metadata_unique_values.clear() if self.has_validation: validation_chunks = list() From ecd2cb4443444298c8af53dce1cf6bf632e4d6cb Mon Sep 17 00:00:00 2001 From: clement-pages Date: Thu, 30 Nov 2023 14:26:42 +0100 Subject: [PATCH 60/83] improve: add dynamic typing for np array in `prepare_data` --- pyannote/audio/core/task.py | 60 +++++++++++++++++++++++++++++++------ 1 file changed, 51 insertions(+), 9 deletions(-) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 8822689f1..6e51680b6 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -279,6 +279,38 @@ def prepare_data(self): ----- Called only once on the main process (and only on it), for global_rank 0. """ + + def get_smallest_type(value: int, unsigned: Optional[bool]=False) -> str: + """Return the most suitable type for storing the + value passed in parameter in memory. + + Parameters + ---------- + value: int + value whose type is best suited to storage in memory + unsigned: bool, optional + positive integer mode only. Default to False + Returns + ------- + str: + numpy formatted type + (see https://numpy.org/doc/stable/reference/arrays.dtypes.html) + """ + if unsigned: + if value < 0: + raise ValueError( + f"negative value ({value}) is incompatible with unsigned types" + ) + # unsigned byte (8 bits), unsigned short (16 bits), unsigned int (32 bits) + types_list = [(255, 'B'), (65_535, 'u2'), (4_294_967_296, 'u4')] + else: + # signe byte (8 bits), signed short (16 bits), signed int (32 bits): + types_list = [(127, 'b'), (32_768, 'i2'), (2_147_483_648, 'i')] + filtered_list = [(max_val, type) for max_val, type in types_list if max_val > abs(value)] + if not filtered_list: + return 'u8' if unsigned else 'i8' # unsigned or signed long (64 bits) + return filtered_list[0][1] + if self.cache_path is not None: cache_path = Path(self.cache_path) if cache_path.exists(): @@ -521,7 +553,9 @@ def prepare_data(self): tuple(metadatum.get(key, -1) for key in metadata_unique_values) for metadatum in metadata ] - dtype = [(key, "b") for key in metadata_unique_values] + dtype = [ + (key, get_smallest_type(max(m[i] for m in metadata))) for i, key in enumerate(metadata_unique_values) + ] prepared_data["metadata"] = np.array(metadata, dtype=dtype) metadata.clear() @@ -531,8 +565,8 @@ def prepare_data(self): # turn list of files metadata into a single numpy array # TODO: improve using https://github.com/pytorch/pytorch/issues/13246#issuecomment-617140519 dtype = [ - ("sample_rate", "i"), - ("num_frames", "i"), + ("sample_rate", get_smallest_type(max(ai[0] for ai in audio_infos), unsigned=True)), + ("num_frames", get_smallest_type(max(ai[1] for ai in audio_infos), unsigned=True)), ("num_channels", "B"), ("bits_per_sample", "B"), ] @@ -544,7 +578,11 @@ def prepare_data(self): annotated_duration.clear() # turn list of annotated regions into a single numpy array - dtype = [("file_id", "i"), ("duration", "f"), ("start", "f")] + dtype = [ + ("file_id", get_smallest_type(max(ar[0] for ar in annotated_regions), unsigned=True)), + ("duration", "f"), + ("start", "f") + ] prepared_data["annotated_regions"] = np.array(annotated_regions, dtype=dtype) annotated_regions.clear() @@ -565,12 +603,12 @@ def prepare_data(self): # turn list of annotations into a single numpy array dtype = [ - ("file_id", "i"), + ("file_id", get_smallest_type(max(a[0] for a in annotations), unsigned=True)), ("start", "f"), ("end", "f"), - ("file_label_idx", "i"), - ("database_label_idx", "i"), - ("global_label_idx", "i"), + ("file_label_idx", get_smallest_type(max(a[3] for a in annotations))), + ("database_label_idx", get_smallest_type(max(a[4] for a in annotations))), + ("global_label_idx", get_smallest_type(max(a[5] for a in annotations))), ] prepared_data["annotations"] = np.array(annotations, dtype=dtype) @@ -602,7 +640,11 @@ def prepare_data(self): start_time = annotated_region["start"] + c * duration validation_chunks.append((file_id, start_time, duration)) - dtype = [("file_id", "i"), ("start", "f"), ("duration", "f")] + dtype = [ + ("file_id", get_smallest_type(max(v[0] for v in validation_chunks), unsigned=True)), + ("start", "f"), + ("duration", "f") + ] prepared_data["validation_chunks"] = np.array(validation_chunks, dtype=dtype) validation_chunks.clear() From fb6d5406d074c058e5d1336ec81d27e88ebaef07 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Mon, 4 Dec 2023 15:21:04 +0100 Subject: [PATCH 61/83] improve: check matching bewteen task current protocol and cached protocol --- pyannote/audio/core/task.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 6e51680b6..1b0db4b6b 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -336,6 +336,7 @@ def get_smallest_type(value: int, unsigned: Optional[bool]=False) -> str: # save all protocol data in a dict prepared_data = {} + prepared_data["protocol_name"] = self.protocol.name # make sure classes attribute exists (and set to None if it did not exist) prepared_data["classes"] = getattr(self, "classes", None) if prepared_data["classes"] is None: @@ -599,7 +600,6 @@ def get_smallest_type(value: int, unsigned: Optional[bool]=False) -> str: annotated_classes_array[file_id, classes] = True prepared_data["annotated_classes"] = annotated_classes_array annotated_classes.clear() - del annotated_classes_array # turn list of annotations into a single numpy array dtype = [ @@ -679,6 +679,12 @@ def setup(self, stage=None): print("""Cached data for protocol not found. Ensure that prepare_data was executed correctly and that the path to the task cache is correct""") raise + # checks that the task current protocol matches the cached protocol + if self.protocol.name != self.prepared_data["protocol_name"]: + raise ValueError( + f"Protocol specified for the task ({self.protocol.name}) " + f"does not correspond to the cached one ({self.prepared_data['protocol_name']})" + ) self.has_setup_metadata = True From 3810308f5cd128a6b34375a9d45306f6f2c4c9bf Mon Sep 17 00:00:00 2001 From: clement-pages Date: Mon, 4 Dec 2023 15:35:37 +0100 Subject: [PATCH 62/83] remove: remove unused argument `stage` in `Task.setup` --- pyannote/audio/core/task.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 1b0db4b6b..b3d8b35cc 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -658,7 +658,7 @@ def get_smallest_type(value: int, unsigned: Optional[bool]=False) -> str: self.has_prepared_data = True - def setup(self, stage=None): + def setup(self): """Setup data on each device""" # if all data was assigned to the task, nothing to do if self.has_setup_metadata: @@ -668,7 +668,8 @@ def setup(self, stage=None): warnings.warn("No path to the directory containing the cache of prepared data" " has been specified. Data preparation will therefore be carried out" " on each process used for training. To speed up data preparation, you" - " can specify a cache directory when instantiating the task.", stacklevel=1) + " can specify a cache directory when instantiating the task.", + stacklevel=1) self.prepare_data() return # load data cached by prepare_data method into the task: From e7da160479a7f84be938f0ceec577c3acf89c9d2 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Fri, 8 Dec 2023 13:47:41 +0100 Subject: [PATCH 63/83] update: change name of attribute `database_ratio` to `dia_task_rate` --- .../joint_task/speaker_diarization_and_embedding.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index ce537c1c1..6c9b701af 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -91,7 +91,7 @@ def __init__( max_speakers_per_frame: int = 2, weigh_by_cardinality: bool = False, batch_size: int = 32, - database_ratio : float = 0.5, + dia_task_rate : float = 0.5, num_workers: int = None, pin_memory: bool = False, margin : float = 28.6, @@ -100,6 +100,7 @@ def __init__( augmentation: BaseWaveformTransform = None, cache_path: Optional[Union[str, None]] = None, ) -> None: + """TODO Add docstring""" super().__init__( protocol, duration=duration, @@ -113,7 +114,7 @@ def __init__( cache_path=cache_path, ) - self.database_ratio = database_ratio + self.dia_task_rate = dia_task_rate self.margin = margin self.scale = scale self.alpha = alpha @@ -369,8 +370,9 @@ def train__iter__helper(self, rng : random.Random, **filters): if np.any(annotated_duration != 0.): prob_annotated_duration = annotated_duration / np.sum(annotated_duration) else: - # There is only files for the embedding subtask - self.database_ratio = 0. + # There is only files for the embedding subtask, so only train on + # this task + self.dia_task_rate = 0. self.alpha = 0. duration = self.duration @@ -384,7 +386,7 @@ def train__iter__helper(self, rng : random.Random, **filters): while True: # choose between diarization or embedding subtask according to a ratio # between these two tasks - if np.random.uniform() < self.database_ratio: + if np.random.uniform() < self.dia_task_rate: file_id, start_time = self.draw_diarization_chunk(file_ids, prob_annotated_duration, rng, duration) From 77ac89fa397b6f8e37e2e950617b86e7c317d95b Mon Sep 17 00:00:00 2001 From: clement-pages Date: Fri, 8 Dec 2023 13:51:09 +0100 Subject: [PATCH 64/83] wip: attempt to fix issues encountered during training --- .../speaker_diarization_and_embedding.py | 35 ++++++++++--------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 6c9b701af..f2e917188 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -359,13 +359,14 @@ def train__iter__helper(self, rng : random.Random, **filters): training &= self.prepared_data["metadata"][key] == value file_ids = np.where(training)[0] # get the subset of embedding database files from training files - embedding_files_ids = file_ids[np.in1d(file_ids, self.global_files_id)] + embedding_files_ids = file_ids[np.isin(file_ids, self.global_files_id)] annotated_duration = self.prepared_data["annotated_duration"][file_ids] # set duration of files for the embedding part to zero, in order to not # drawn them for diarization part annotated_duration[embedding_files_ids] = 0. - # test if there is at least one file for the diarization subtask to avoid + + # test if there is at least one file for the diarization subtask # to prevent probabilities from summing to zero if np.any(annotated_duration != 0.): prob_annotated_duration = annotated_duration / np.sum(annotated_duration) @@ -377,10 +378,10 @@ def train__iter__helper(self, rng : random.Random, **filters): duration = self.duration - # make a copy of the original class list, so as not to modify it during shuffling - embedding_classes = self.specifications[Subtasks.index("embedding")].classes # use original order for the first run of the shuffled classes list: - shuffled_embedding_classes = list(embedding_classes) + shuffled_embedding_classes = list( + self.specifications[Subtasks.index("embedding")].classes + ) embedding_class_idx = 0 while True: @@ -647,11 +648,14 @@ def compute_embedding_loss(self, emb_chunks, emb_prediction, target_emb): """ # Get speaker representations from the embedding subtask - embeddings = emb_prediction[emb_chunks] + embeddings = rearrange(emb_prediction[emb_chunks], "b s e -> (b s) e") # Get corresponding target label - targets = target_emb[emb_chunks] + targets = rearrange(target_emb[emb_chunks], "b s -> (b s)") + # compute loss only on global scope speaker embedding + valid_emb = targets != -1 + # compute the loss - emb_loss = self.model.arc_face_loss(embeddings, targets) + emb_loss = self.model.arc_face_loss(embeddings[valid_emb, :], targets[valid_emb]) # skip batch if something went wrong for some reason if torch.isnan(emb_loss): @@ -689,6 +693,7 @@ def training_step(self, batch, batch_idx: int): target_dia = batch["y_dia"] # batch embedding references (batch, num_speakers) target_emb = batch["y_emb"] + meta = batch["meta"] # drop samples that contain too many speakers num_speakers: torch.Tensor = torch.sum(torch.any(target_dia, dim=1), dim=1) @@ -713,21 +718,19 @@ def training_step(self, batch, batch_idx: int): # filter out the speaker in the reference that were not found by the diarization # part of the model, to not compute the embedding loss on these speaker: - active_spk_mask = torch.any(rearrange(dia_multilabel, "b f s -> b s f"), dim=2) + # active_spk_mask = torch.any(rearrange(dia_multilabel, "b f s -> b s f"), dim=2) # (batch_size, num_spk) - emb_prediction = emb_prediction[active_spk_mask] + # emb_prediction = emb_prediction[active_spk_mask] # (num_active_spk_found_in_all_the_chunks, emb_size) - permutated_target_emb = permutated_target_emb[active_spk_mask] + # permutated_target_emb = permutated_target_emb[permutated_target_emb != 1] # (num_activate_spk_found,) permutated_target_powerset = self.model.powerset.to_powerset( permutated_target_dia.float() ) - - # get embedding chunks position in current batch - emb_chunks = permutated_target_emb != -1 - # get diarization chunks position in current batch (that correspond to non embedding chunks) - dia_chunks = torch.nonzero(torch.all(target_emb == -1, axis=1)).reshape((-1,)) + # get embedding and diarization chunks position in current batch + emb_chunks = batch["meta"]["scope"] == 2 # global scope for embedding task + dia_chunks = batch["meta"]["scope"] < 2 # file and database scope for diarization task dia_loss = torch.tensor(0) #if batch contains diarization subtask chunks, then compute diarization loss on these chunks: From ea6d06dde8c258e6fcb4ba76c8185e8d266a72c8 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Fri, 8 Dec 2023 13:56:06 +0100 Subject: [PATCH 65/83] update: use all the `pyannet` pretrained model --- pyannote/audio/models/joint/end_to_end_diarization.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pyannote/audio/models/joint/end_to_end_diarization.py b/pyannote/audio/models/joint/end_to_end_diarization.py index 591e32836..453cf6bc8 100644 --- a/pyannote/audio/models/joint/end_to_end_diarization.py +++ b/pyannote/audio/models/joint/end_to_end_diarization.py @@ -69,14 +69,11 @@ def __init__( self.pyannet = Model.from_pretrained( "pyannote/segmentation-3.0", use_auth_token=os.environ["HUGGINGFACE_TOKEN"], - strict=False ) def build(self): """""" dia_specs = self.specifications[Subtasks.index("diarization")] - self.pyannet.specifications = dia_specs - self.pyannet.build() self.powerset = Powerset( len(dia_specs.classes), dia_specs.powerset_max_classes, @@ -91,7 +88,7 @@ def forward(self, waveformms: torch.Tensor) -> torch.Tensor: Batch of waveforms with shape (batch, channel, sample) """ dia_outputs = self.pyannet(waveformms) - weights = self.powerset.to_multilabel(dia_outputs) + weights = self.powerset.to_multilabel(dia_outputs, soft=True) weights = rearrange(weights, "b f s -> b s f") emb_outputs = self.resnet34(waveformms, weights) return (dia_outputs, emb_outputs) @@ -236,7 +233,6 @@ def build(self): diarization_spec = self.specifications[Subtasks.index("diarization")] out_features = diarization_spec.num_powerset_classes self.classifier = nn.Linear(in_features, out_features) - self.powerset = Powerset( len(diarization_spec.classes), diarization_spec.powerset_max_classes, From 185798de5157679313d32471fe00f72cde75c0ab Mon Sep 17 00:00:00 2001 From: clement-pages Date: Fri, 8 Dec 2023 15:23:02 +0100 Subject: [PATCH 66/83] fix: fix diarization loss calculation condition in `training_step` --- .../audio/tasks/joint_task/speaker_diarization_and_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index f2e917188..e3754fc5c 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -734,7 +734,7 @@ def training_step(self, batch, batch_idx: int): dia_loss = torch.tensor(0) #if batch contains diarization subtask chunks, then compute diarization loss on these chunks: - if dia_chunks.shape[0] > 0: + if dia_chunks.any(): dia_loss = self.compute_diarization_loss(dia_chunks, dia_prediction, permutated_target_powerset) emb_loss = torch.tensor(0) From 9d13697d351d78696be9dcbadbc372ab380c4c0e Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 14 May 2024 10:11:20 +0200 Subject: [PATCH 67/83] update joint task with last modifications on data preparation --- .../speaker_diarization_and_embedding.py | 271 +++++++++--------- 1 file changed, 137 insertions(+), 134 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index e3754fc5c..603134fb6 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -20,40 +20,33 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from collections import defaultdict -from einops import rearrange -import math import itertools +import math import random from typing import Dict, Literal, Optional, Sequence, Union -import warnings -from matplotlib import pyplot as plt + import numpy as np import torch - +from einops import rearrange +from matplotlib import pyplot as plt +from pyannote.core import Segment, SlidingWindowFeature +from pyannote.database.protocol.protocol import Scope, Subset +from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger +from pytorch_metric_learning.losses import ArcFaceLoss from torch_audiomentations.core.transforms_interface import BaseWaveformTransform -from torchaudio.backend.common import AudioMetaData from torchmetrics import Metric -from torchmetrics.classification import BinaryAUROC -from torch.utils.data._utils.collate import default_collate -from pytorch_metric_learning.losses import ArcFaceLoss -from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger -from pyannote.core import Segment, SlidingWindowFeature from pyannote.audio.core.task import Problem, Resolution, Specifications -from pyannote.audio.utils.loss import nll_loss -from pyannote.audio.utils.permutation import permutate -from pyannote.audio.utils.random import create_rng_for_worker from pyannote.audio.tasks import SpeakerDiarization -from pyannote.database.protocol import SpeakerDiarizationProtocol -from pyannote.database.protocol.protocol import Scope, Subset -from pyannote.audio.torchmetrics.classification import EqualErrorRate from pyannote.audio.torchmetrics import ( DiarizationErrorRate, FalseAlarmRate, MissedDetectionRate, SpeakerConfusionRate, ) +from pyannote.audio.utils.loss import nll_loss +from pyannote.audio.utils.permutation import permutate +from pyannote.audio.utils.random import create_rng_for_worker Subtask = Literal["diarization", "embedding"] @@ -84,21 +77,21 @@ class JointSpeakerDiarizationAndEmbedding(SpeakerDiarization): """ def __init__( - self, - protocol, - duration: float = 5.0, - max_speakers_per_chunk: int = 3, - max_speakers_per_frame: int = 2, - weigh_by_cardinality: bool = False, - batch_size: int = 32, - dia_task_rate : float = 0.5, - num_workers: int = None, - pin_memory: bool = False, - margin : float = 28.6, - scale: float = 64.0, - alpha: float = 0.5, - augmentation: BaseWaveformTransform = None, - cache_path: Optional[Union[str, None]] = None, + self, + protocol, + duration: float = 5.0, + max_speakers_per_chunk: int = 3, + max_speakers_per_frame: int = 2, + weigh_by_cardinality: bool = False, + batch_size: int = 32, + dia_task_rate: float = 0.5, + num_workers: int = None, + pin_memory: bool = False, + margin: float = 28.6, + scale: float = 64.0, + alpha: float = 0.5, + augmentation: BaseWaveformTransform = None, + cache: Optional[Union[str, None]] = None, ) -> None: """TODO Add docstring""" super().__init__( @@ -111,7 +104,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, augmentation=augmentation, - cache_path=cache_path, + cache=cache, ) self.dia_task_rate = dia_task_rate @@ -121,31 +114,7 @@ def __init__( # keep track of the use of database available in the meta protocol # * embedding databases are those with global speaker label scope # * diarization databases are those with file or database speaker label scope - self.global_files_id = [] - - def get_file(self, file_id): - - file = dict() - - file["audio"] = str(self.prepared_data["audios"][file_id], encoding="utf-8") - - _audio_info = self.prepared_data["audio_infos"][file_id] - _encoding = self.prepared_data["audio_encodings"][file_id] - - sample_rate = _audio_info["sample_rate"] - num_frames = _audio_info["num_frames"] - num_channels = _audio_info["num_channels"] - bits_per_sample = _audio_info["bits_per_sample"] - encoding = str(_encoding, encoding="utf-8") - file["torchaudio.info"] = AudioMetaData( - sample_rate=sample_rate, - num_frames=num_frames, - num_channels=num_channels, - bits_per_sample=bits_per_sample, - encoding=encoding, - ) - - return file + self.embedding_files_id = [] def setup(self, stage="fit"): """Setup method @@ -157,23 +126,30 @@ def setup(self, stage="fit"): """ super().setup() - global_scope_mask = self.prepared_data["annotations"]["global_label_idx"] > -1 - self.global_files_id = np.unique(self.prepared_data["annotations"]["file_id"][global_scope_mask]) - global_classes = np.unique(self.prepared_data["annotations"]["global_label_idx"][global_scope_mask]) + + database_scope_mask = self.prepared_data["audio-metadata"]["scope"] > 0 + self.embedding_files_id = np.unique( + self.prepared_data["annotations-segments"]["file_id"][database_scope_mask] + ) + embedding_classes = np.unique( + self.prepared_data["annotations-segments"]["global_label_idx"][ + database_scope_mask + ] + ) speaker_diarization = Specifications( - duration=self.duration, - resolution=Resolution.FRAME, - problem=Problem.MONO_LABEL_CLASSIFICATION, - permutation_invariant=True, - classes=[f"speaker{i+1}" for i in range(self.max_speakers_per_chunk)], - powerset_max_classes=self.max_speakers_per_frame, - ) + duration=self.duration, + resolution=Resolution.FRAME, + problem=Problem.MONO_LABEL_CLASSIFICATION, + permutation_invariant=True, + classes=[f"speaker{i+1}" for i in range(self.max_speakers_per_chunk)], + powerset_max_classes=self.max_speakers_per_frame, + ) speaker_embedding = Specifications( duration=self.duration, resolution=Resolution.CHUNK, problem=Problem.REPRESENTATION, - classes=global_classes, + classes=embedding_classes, ) self.specifications = (speaker_diarization, speaker_embedding) @@ -208,16 +184,20 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): file = self.get_file(file_id) # get label scope - label_scope = Scopes[self.prepared_data["metadata"][file_id]["scope"]] + label_scope = Scopes[self.prepared_data["audio-metadata"][file_id]["scope"]] label_scope_key = f"{label_scope}_label_idx" chunk = Segment(start_time, start_time + duration) sample = dict() - sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration, mode="pad") + sample["X"], _ = self.model.audio.crop( + file, chunk, duration=duration, mode="pad" + ) # gather all annotations of current file - annotations = self.prepared_data["annotations"][self.prepared_data["annotations"]["file_id"] == file_id] + annotations = self.prepared_data["annotations-segments"][ + self.prepared_data["annotations-segments"]["file_id"] == file_id + ] # gather all annotations with non-empty intersection with current chunk chunk_annotations = annotations[ @@ -227,7 +207,9 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # discretize chunk annotations at model output resolution start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start # TODO handle tuple outputs from the model - start_idx = np.floor(start / self.model.example_output[0].frames.step).astype(int) + start_idx = np.floor(start / self.model.example_output[0].frames.step).astype( + int + ) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start end_idx = np.ceil(end / self.model.example_output[0].frames.step).astype(int) @@ -239,7 +221,9 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): pass # initial frame-level targets - y = np.zeros((self.model.example_output[0].num_frames, num_labels), dtype=np.uint8) + y = np.zeros( + (self.model.example_output[0].num_frames, num_labels), dtype=np.uint8 + ) # map labels to indices mapping = {label: idx for idx, label in enumerate(labels)} @@ -253,17 +237,19 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample["y"] = SlidingWindowFeature( y, self.model.example_output[0].frames, labels=labels ) - metadata = self.prepared_data["metadata"][file_id] + metadata = self.prepared_data["audio-metadata"][file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} sample["meta"]["file"] = file_id return sample - def draw_diarization_chunk(self, file_ids : np.ndarray, - prob_annotated_duration : np.ndarray, - rng : random.Random, - duration : float, - ) -> tuple: + def draw_diarization_chunk( + self, + file_ids: np.ndarray, + prob_annotated_duration: np.ndarray, + rng: random.Random, + duration: float, + ) -> tuple: """Sample one chunk for the diarization task Parameters @@ -286,23 +272,28 @@ def draw_diarization_chunk(self, file_ids : np.ndarray, )[0] # turn annotated regions duration into a probability distribution - prob_annotaded_regions_duration = self.prepared_data["annotated_regions"]["duration"][ - annotated_region_indices - ] / np.sum(self.prepared_data["annotated_regions"]["duration"][annotated_region_indices]) + prob_annotaded_regions_duration = self.prepared_data["annotations-regions"][ + "duration" + ][annotated_region_indices] / np.sum( + self.prepared_data["annotations-regions"]["duration"][ + annotated_region_indices + ] + ) # seletect one annotated region at random (with probability proportional to its duration) - annotated_region_index = np.random.choice(annotated_region_indices, - p=prob_annotaded_regions_duration - ) + annotated_region_index = np.random.choice( + annotated_region_indices, p=prob_annotaded_regions_duration + ) # select one chunk at random in this annotated region - _, region_duration, start = self.prepared_data["annotated_regions"][annotated_region_index] + _, region_duration, start = self.prepared_data["annotations-regions"][ + annotated_region_index + ] start_time = rng.uniform(start, start + region_duration - duration) return (file_id, start_time) - def draw_embedding_chunk(self, class_id : int, - duration : float) -> tuple: + def draw_embedding_chunk(self, class_id: int, duration: float) -> tuple: """Sample one chunk for the embedding task Parameters @@ -322,8 +313,10 @@ class ID in the task speficiations """ # get index of the current class in the order of original class list # get segments for current class - class_segments_idx = self.prepared_data["annotations"]["global_label_idx"] == class_id - class_segments = self.prepared_data["annotations"][class_segments_idx] + class_segments_idx = ( + self.prepared_data["annotations-segments"]["global_label_idx"] == class_id + ) + class_segments = self.prepared_data["annotations-segments"][class_segments_idx] # sample one segment from all the class segments: segments_duration = class_segments["end"] - class_segments["start"] @@ -332,11 +325,13 @@ class ID in the task speficiations segment = np.random.choice(class_segments, p=prob_segments) # sample chunk start time in order to intersect it with the sampled segment - start_time = np.random.uniform(max(segment["start"] - duration, 0), segment["end"]) + start_time = np.random.uniform( + max(segment["start"] - duration, 0), segment["end"] + ) return (segment["file_id"], start_time) - def train__iter__helper(self, rng : random.Random, **filters): + def train__iter__helper(self, rng: random.Random, **filters): """Iterate over training samples with optional domain filtering Parameters @@ -354,27 +349,29 @@ def train__iter__helper(self, rng : random.Random, **filters): """ # indices of training files that matches domain filters - training = self.prepared_data["metadata"]["subset"] == Subsets.index("train") + training = self.prepared_data["metadata-values"]["subset"] == Subsets.index( + "train" + ) for key, value in filters.items(): - training &= self.prepared_data["metadata"][key] == value + training &= self.prepared_data["metadata-values"][key] == value file_ids = np.where(training)[0] # get the subset of embedding database files from training files - embedding_files_ids = file_ids[np.isin(file_ids, self.global_files_id)] + embedding_files_ids = file_ids[np.isin(file_ids, self.embedding_files_id)] - annotated_duration = self.prepared_data["annotated_duration"][file_ids] + annotated_duration = self.prepared_data["audio-annotated"][file_ids] # set duration of files for the embedding part to zero, in order to not # drawn them for diarization part - annotated_duration[embedding_files_ids] = 0. + annotated_duration[embedding_files_ids] = 0.0 # test if there is at least one file for the diarization subtask # to prevent probabilities from summing to zero - if np.any(annotated_duration != 0.): + if np.any(annotated_duration != 0.0): prob_annotated_duration = annotated_duration / np.sum(annotated_duration) else: # There is only files for the embedding subtask, so only train on # this task - self.dia_task_rate = 0. - self.alpha = 0. + self.dia_task_rate = 0.0 + self.alpha = 0.0 duration = self.duration @@ -388,9 +385,9 @@ def train__iter__helper(self, rng : random.Random, **filters): # choose between diarization or embedding subtask according to a ratio # between these two tasks if np.random.uniform() < self.dia_task_rate: - file_id, start_time = self.draw_diarization_chunk(file_ids, prob_annotated_duration, - rng, - duration) + file_id, start_time = self.draw_diarization_chunk( + file_ids, prob_annotated_duration, rng, duration + ) else: # shuffle embedding classes list and go through this shuffled list # to make sure to see all the speakers during training @@ -428,8 +425,10 @@ def train__iter__(self): else: # create subchunks = dict() - for product in itertools.product([self.prepared_data["metadata_unique_values"][key] for key in balance]): - filters = {key : value for key, value in zip(balance, product)} + for product in itertools.product( + [self.prepared_data["metadata-values"][key] for key in balance] + ): + filters = {key: value for key, value in zip(balance, product)} subchunks[product] = self.train__iter__helper(rng, **filters) while True: @@ -489,17 +488,19 @@ def collate_y(self, batch) -> torch.Tensor: mode="constant", ) if b["meta"]["scope"] > 1: - y_emb[: num_speakers] = labels[:] + y_emb[:num_speakers] = labels[:] else: if b["meta"]["scope"] > 1: - y_emb[: num_speakers] = labels[:] + y_emb[:num_speakers] = labels[:] collated_y_dia.append(y_dia) collate_y_emb.append(y_emb) - return (torch.from_numpy(np.stack(collated_y_dia)), - torch.from_numpy(np.stack(collate_y_emb)).squeeze(1)) + return ( + torch.from_numpy(np.stack(collated_y_dia)), + torch.from_numpy(np.stack(collate_y_emb)).squeeze(1), + ) def collate_fn(self, batch, stage="train"): """Collate function used for most segmentation tasks @@ -615,8 +616,7 @@ def compute_diarization_loss(self, dia_chunks, dia_prediction, permutated_target # Get the permutated reference corresponding to diarization subtask permutated_target_dia = permutated_target[dia_chunks] # Compute segmentation loss - dia_loss = self.segmentation_loss(chunks_prediction, - permutated_target_dia) + dia_loss = self.segmentation_loss(chunks_prediction, permutated_target_dia) self.model.log( "loss/train/dia", dia_loss, @@ -653,9 +653,11 @@ def compute_embedding_loss(self, emb_chunks, emb_prediction, target_emb): targets = rearrange(target_emb[emb_chunks], "b s -> (b s)") # compute loss only on global scope speaker embedding valid_emb = targets != -1 - + # compute the loss - emb_loss = self.model.arc_face_loss(embeddings[valid_emb, :], targets[valid_emb]) + emb_loss = self.model.arc_face_loss( + embeddings[valid_emb, :], targets[valid_emb] + ) # skip batch if something went wrong for some reason if torch.isnan(emb_loss): @@ -693,7 +695,6 @@ def training_step(self, batch, batch_idx: int): target_dia = batch["y_dia"] # batch embedding references (batch, num_speakers) target_emb = batch["y_emb"] - meta = batch["meta"] # drop samples that contain too many speakers num_speakers: torch.Tensor = torch.sum(torch.any(target_dia, dim=1), dim=1) @@ -713,8 +714,9 @@ def training_step(self, batch, batch_idx: int): # get the best permutation dia_multilabel = self.model.powerset.to_multilabel(dia_prediction) permutated_target_dia, permut_map = permutate(dia_multilabel, target_dia) - permutated_target_emb = target_emb[torch.arange(target_emb.shape[0]).unsqueeze(1), - permut_map] + permutated_target_emb = target_emb[ + torch.arange(target_emb.shape[0]).unsqueeze(1), permut_map + ] # filter out the speaker in the reference that were not found by the diarization # part of the model, to not compute the embedding loss on these speaker: @@ -730,17 +732,23 @@ def training_step(self, batch, batch_idx: int): ) # get embedding and diarization chunks position in current batch emb_chunks = batch["meta"]["scope"] == 2 # global scope for embedding task - dia_chunks = batch["meta"]["scope"] < 2 # file and database scope for diarization task + dia_chunks = ( + batch["meta"]["scope"] < 2 + ) # file and database scope for diarization task dia_loss = torch.tensor(0) - #if batch contains diarization subtask chunks, then compute diarization loss on these chunks: + # if batch contains diarization subtask chunks, then compute diarization loss on these chunks: if dia_chunks.any(): - dia_loss = self.compute_diarization_loss(dia_chunks, dia_prediction, permutated_target_powerset) + dia_loss = self.compute_diarization_loss( + dia_chunks, dia_prediction, permutated_target_powerset + ) emb_loss = torch.tensor(0) # if batch contains embedding subtask chunks, then compute embedding loss on these chunks: if emb_chunks.any(): - emb_loss = self.compute_embedding_loss(emb_chunks, emb_prediction, permutated_target_emb) + emb_loss = self.compute_embedding_loss( + emb_chunks, emb_prediction, permutated_target_emb + ) loss = alpha * dia_loss + (1 - alpha) * emb_loss return {"loss": loss} @@ -760,7 +768,6 @@ def validation_step(self, batch, batch_idx: int): # target target_dia = batch["y_dia"] # (batch_size, num_frames, num_speakers) - target_emb = batch["y_emb"] waveform = batch["X"] # (batch_size, num_channels, num_samples) @@ -770,7 +777,7 @@ def validation_step(self, batch, batch_idx: int): # target = target[keep] # forward pass - dia_prediction, emb_prediction = self.model(waveform) + dia_prediction, _ = self.model(waveform) batch_size, num_frames, _ = dia_prediction.shape multilabel = self.model.powerset.to_multilabel(dia_prediction) @@ -782,7 +789,8 @@ def validation_step(self, batch, batch_idx: int): permutated_target.float() ) seg_loss = self.segmentation_loss( - dia_prediction, permutated_target_powerset, + dia_prediction, + permutated_target_powerset, ) self.model.log( @@ -795,12 +803,8 @@ def validation_step(self, batch, batch_idx: int): ) self.model.validation_metric( - torch.transpose( - multilabel, 1, 2 - ), - torch.transpose( - target_dia, 1, 2 - ), + torch.transpose(multilabel, 1, 2), + torch.transpose(target_dia, 1, 2), ) self.model.log_dict( @@ -875,7 +879,6 @@ def validation_step(self, batch, batch_idx: int): plt.close(fig) - def default_metric( self, ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: @@ -887,6 +890,6 @@ def default_metric( "DiarizationErrorRate/Confusion": SpeakerConfusionRate(0.5), "DiarizationErrorRate/Miss": MissedDetectionRate(0.5), "DiarizationErrorRate/FalseAlarm": FalseAlarmRate(0.5), - #"EqualErrorRate": EqualErrorRate(compute_on_cpu=True, distances=False), - #"BinaryAUROC": BinaryAUROC(compute_on_cpu=True), + # "EqualErrorRate": EqualErrorRate(compute_on_cpu=True, distances=False), + # "BinaryAUROC": BinaryAUROC(compute_on_cpu=True), } From 6c67fc69c69d7a0cf71e31bf37a194c607593156 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 14 May 2024 13:36:12 +0200 Subject: [PATCH 68/83] update the way batches are generated in the joint task Now, the first `num_dia_samples` samples in a batch are dedicated to the diarization substak, and the remaining sample are for the embedding subtask --- .../speaker_diarization_and_embedding.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 603134fb6..ecfc16ae4 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -107,7 +107,7 @@ def __init__( cache=cache, ) - self.dia_task_rate = dia_task_rate + self.num_dia_samples = int(batch_size * dia_task_rate) self.margin = margin self.scale = scale self.alpha = alpha @@ -127,13 +127,15 @@ def setup(self, stage="fit"): super().setup() - database_scope_mask = self.prepared_data["audio-metadata"]["scope"] > 0 + global_scope_mask = ( + self.prepared_data["annotations-segments"]["global_label_idx"] > -1 + ) self.embedding_files_id = np.unique( - self.prepared_data["annotations-segments"]["file_id"][database_scope_mask] + self.prepared_data["annotations-segments"]["file_id"][global_scope_mask] ) embedding_classes = np.unique( self.prepared_data["annotations-segments"]["global_label_idx"][ - database_scope_mask + global_scope_mask ] ) @@ -370,7 +372,7 @@ def train__iter__helper(self, rng: random.Random, **filters): else: # There is only files for the embedding subtask, so only train on # this task - self.dia_task_rate = 0.0 + self.num_dia_samples = 0.0 self.alpha = 0.0 duration = self.duration @@ -379,12 +381,12 @@ def train__iter__helper(self, rng: random.Random, **filters): shuffled_embedding_classes = list( self.specifications[Subtasks.index("embedding")].classes ) + + sample_idx = 0 embedding_class_idx = 0 while True: - # choose between diarization or embedding subtask according to a ratio - # between these two tasks - if np.random.uniform() < self.dia_task_rate: + if sample_idx < self.num_dia_samples: file_id, start_time = self.draw_diarization_chunk( file_ids, prob_annotated_duration, rng, duration ) @@ -399,6 +401,7 @@ def train__iter__helper(self, rng: random.Random, **filters): file_id, start_time = self.draw_embedding_chunk(klass, duration) sample = self.prepare_chunk(file_id, start_time, duration) + sample_idx = (sample_idx + 1) % self.batch_size yield sample def train__iter__(self): From 519db893cc9a8e0dfebc6b3610109e83a6acfb13 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 14 May 2024 14:52:03 +0200 Subject: [PATCH 69/83] fix random generators --- .../speaker_diarization_and_embedding.py | 37 +++++++++++-------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index ecfc16ae4..f31e1022b 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -248,7 +248,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): def draw_diarization_chunk( self, file_ids: np.ndarray, - prob_annotated_duration: np.ndarray, + cum_prob_annotated_duration: np.ndarray, rng: random.Random, duration: float, ) -> tuple: @@ -258,7 +258,7 @@ def draw_diarization_chunk( ---------- file_ids: np.ndarray array containing files id - prob_annotated_duration: np.ndarray + cum_prob_annotated_duration: np.ndarray array of the same size than file_ids array, containing probability to corresponding file to be drawn rng : random.Random @@ -267,25 +267,28 @@ def draw_diarization_chunk( duration of the chunk to draw """ # select one file at random (wiht probability proportional to its annotated duration) - file_id = np.random.choice(file_ids, p=prob_annotated_duration) + file_id = file_ids[cum_prob_annotated_duration.searchsorted(rng.random())] # find indices of annotated regions in this file annotated_region_indices = np.where( - self.prepared_data["annotated_regions"]["file_id"] == file_id + self.prepared_data["annotations-regions"]["file_id"] == file_id )[0] # turn annotated regions duration into a probability distribution - prob_annotaded_regions_duration = self.prepared_data["annotations-regions"][ - "duration" - ][annotated_region_indices] / np.sum( + cum_prob_annotaded_regions_duration = np.cumsum( self.prepared_data["annotations-regions"]["duration"][ annotated_region_indices ] + / np.sum( + self.prepared_data["annotations-regions"]["duration"][ + annotated_region_indices + ] + ) ) # seletect one annotated region at random (with probability proportional to its duration) - annotated_region_index = np.random.choice( - annotated_region_indices, p=prob_annotaded_regions_duration - ) + annotated_region_index = annotated_region_indices[ + cum_prob_annotaded_regions_duration.searchsorted(rng.random()) + ] # select one chunk at random in this annotated region _, region_duration, start = self.prepared_data["annotations-regions"][ @@ -351,11 +354,13 @@ def train__iter__helper(self, rng: random.Random, **filters): """ # indices of training files that matches domain filters - training = self.prepared_data["metadata-values"]["subset"] == Subsets.index( + training = self.prepared_data["audio-metadata"]["subset"] == Subsets.index( "train" ) for key, value in filters.items(): - training &= self.prepared_data["metadata-values"][key] == value + training &= self.prepared_data["audio-metadata"][key] == self.prepared_data[ + "metadata" + ][key].index(value) file_ids = np.where(training)[0] # get the subset of embedding database files from training files embedding_files_ids = file_ids[np.isin(file_ids, self.embedding_files_id)] @@ -368,7 +373,9 @@ def train__iter__helper(self, rng: random.Random, **filters): # test if there is at least one file for the diarization subtask # to prevent probabilities from summing to zero if np.any(annotated_duration != 0.0): - prob_annotated_duration = annotated_duration / np.sum(annotated_duration) + cum_prob_annotated_duration = np.cumsum( + annotated_duration / np.sum(annotated_duration) + ) else: # There is only files for the embedding subtask, so only train on # this task @@ -388,7 +395,7 @@ def train__iter__helper(self, rng: random.Random, **filters): while True: if sample_idx < self.num_dia_samples: file_id, start_time = self.draw_diarization_chunk( - file_ids, prob_annotated_duration, rng, duration + file_ids, cum_prob_annotated_duration, rng, duration ) else: # shuffle embedding classes list and go through this shuffled list @@ -420,7 +427,7 @@ def train__iter__(self): """ # create worker-specific random number generator - rng = create_rng_for_worker(self.model.current_epoch) + rng = create_rng_for_worker(self.model) balance = getattr(self, "balance", None) if balance is None: From 106bfc538ddac6dc6909dfea10236b0e543c4808 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 14 May 2024 14:54:44 +0200 Subject: [PATCH 70/83] delete remaining call to `example_output` --- .../models/joint/end_to_end_diarization.py | 180 ++++++++++++------ .../speaker_diarization_and_embedding.py | 30 ++- 2 files changed, 140 insertions(+), 70 deletions(-) diff --git a/pyannote/audio/models/joint/end_to_end_diarization.py b/pyannote/audio/models/joint/end_to_end_diarization.py index 453cf6bc8..bfc0bec65 100644 --- a/pyannote/audio/models/joint/end_to_end_diarization.py +++ b/pyannote/audio/models/joint/end_to_end_diarization.py @@ -21,22 +21,22 @@ # SOFTWARE. import os -from typing import List, Literal, Optional, Union +from functools import lru_cache +from typing import Literal, Optional, Union from warnings import warn -from einops import rearrange import torch -from torch import nn import torch.nn.functional as F +from einops import rearrange +from pyannote.core.utils.generators import pairwise +from torch import nn from pyannote.audio.core.model import Model from pyannote.audio.core.task import Task -from pyannote.audio.models.blocks.sincnet import SincNet from pyannote.audio.models.blocks.pooling import StatsPool +from pyannote.audio.models.blocks.sincnet import SincNet from pyannote.audio.utils.params import merge_dict from pyannote.audio.utils.powerset import Powerset -from pyannote.core.utils.generators import pairwise - # TODO deplace these two lines into uitls/multi_task Subtask = Literal["diarization", "embedding"] @@ -48,16 +48,17 @@ class WeSpeakerBasesEndToEndDiarization(Model): WeSpeaker-based joint speaker diarization and speaker embedding extraction model """ + def __init__( - self, - sincnet: dict = None, - lstm: dict = None, - linear: dict = None, - sample_rate=16000, - embedding_dim=256, - num_channels=1, - task: Optional[Union[Task, None]] = None, - ): + self, + sincnet: dict = None, + lstm: dict = None, + linear: dict = None, + sample_rate=16000, + embedding_dim=256, + num_channels=1, + task: Optional[Union[Task, None]] = None, + ): super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) # speakers embedding extraction submodel: @@ -68,9 +69,20 @@ def __init__( # speaker segmentation submodel: self.pyannet = Model.from_pretrained( "pyannote/segmentation-3.0", - use_auth_token=os.environ["HUGGINGFACE_TOKEN"], + use_auth_token=os.environ["HG_TOKEN"], ) + @property + def dimension(self) -> int: + """Dimension of output""" + if isinstance(self.specifications, tuple): + raise ValueError("PyanNet does not support multi-tasking.") + + if self.specifications.powerset: + return self.specifications.num_powerset_classes + else: + return len(self.specifications.classes) + def build(self): """""" dia_specs = self.specifications[Subtasks.index("diarization")] @@ -79,6 +91,54 @@ def build(self): dia_specs.powerset_max_classes, ) + @lru_cache + def num_frames(self, num_samples: int) -> int: + """Compute number of output frames for a given number of input samples + + Parameters + ---------- + num_samples : int + Number of input samples + + Returns + ------- + num_frames : int + Number of output frames + """ + + return self.pyannet.sincnet.num_frames(num_samples) + + def receptive_field_size(self, num_frames: int = 1) -> int: + """Compute size of receptive field + + Parameters + ---------- + num_frames : int, optional + Number of frames in the output signal + + Returns + ------- + receptive_field_size : int + Receptive field size. + """ + return self.pyannet.sincnet.receptive_field_size(num_frames=num_frames) + + def receptive_field_center(self, frame: int = 0) -> int: + """Compute center of receptive field + + Parameters + ---------- + frame : int, optional + Frame index + + Returns + ------- + receptive_field_center : int + Index of receptive field center. + """ + + return self.pyannet.sincnet.receptive_field_center(frame=frame) + def forward(self, waveformms: torch.Tensor) -> torch.Tensor: """ @@ -97,8 +157,9 @@ def forward(self, waveformms: torch.Tensor) -> torch.Tensor: class SpeakerEndToEndDiarization(Model): """Speaker End-to-End Diarization and Embedding model SINCNET -- TDNN .. TDNN -- TDNN ..TDNN -- StatsPool -- Linear -- Classifier - \ LSTM ... LSTM -- FeedForward -- Classifier + \\ LSTM ... LSTM -- FeedForward -- Classifier """ + SINCNET_DEFAULTS = {"stride": 10} LSTM_DEFAULTS = { "hidden_size": 128, @@ -111,31 +172,32 @@ class SpeakerEndToEndDiarization(Model): LINEAR_DEFAULTS = {"hidden_size": 128, "num_layers": 2} def __init__( - self, - sincnet: dict = None, - lstm: dict= None, - linear: dict = None, - sample_rate: int = 16000, - num_channels: int = 1, - num_features: int = 60, - embedding_dim: int = 512, - separation_idx: int = 2, - task: Optional[Task] = None, - ): + self, + sincnet: dict = None, + lstm: dict = None, + linear: dict = None, + sample_rate: int = 16000, + num_channels: int = 1, + num_features: int = 60, + embedding_dim: int = 512, + separation_idx: int = 2, + task: Optional[Task] = None, + ): super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) if num_features != 60: - warn("For now, the model only support a number of features of 60. Set it to 60") + warn( + "For now, the model only support a number of features of 60. Set it to 60" + ) num_features = 60 self.num_features = num_features self.separation_idx = separation_idx self.save_hyperparameters("num_features", "embedding_dim", "separation_idx") - # sincnet module sincnet = merge_dict(self.SINCNET_DEFAULTS, sincnet) sincnet["sample_rate"] = sample_rate - self.sincnet =SincNet(**sincnet) + self.sincnet = SincNet(**sincnet) self.save_hyperparameters("sincnet") # tdnn modules @@ -238,7 +300,9 @@ def build(self): diarization_spec.powerset_max_classes, ) - def forward(self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = None + ) -> torch.Tensor: """ Parameters @@ -261,7 +325,9 @@ def forward(self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = Non rearrange(common_outputs, "batch feature frame -> batch frame feature") ) else: - diarization_outputs = rearrange(common_outputs, "batch feature frame -> batch frame feature") + diarization_outputs = rearrange( + common_outputs, "batch feature frame -> batch frame feature" + ) for i, lstm in enumerate(self.lstm): diarization_outputs, _ = lstm(diarization_outputs) if i + 1 < self.hparams.lstm["num_layers"]: @@ -299,31 +365,32 @@ class SpeakerEndToEndDiarizationV2(Model): LINEAR_DEFAULTS = {"hidden_size": 128, "num_layers": 2} def __init__( - self, - sincnet: dict = None, - lstm: dict= None, - linear: dict = None, - sample_rate: int = 16000, - num_channels: int = 1, - num_features: int = 60, - embedding_dim: int = 512, - separation_idx: int = 2, - task: Optional[Task] = None, - ): + self, + sincnet: dict = None, + lstm: dict = None, + linear: dict = None, + sample_rate: int = 16000, + num_channels: int = 1, + num_features: int = 60, + embedding_dim: int = 512, + separation_idx: int = 2, + task: Optional[Task] = None, + ): super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) if num_features != 60: - warn("For now, the model only support a number of features of 60. Set it to 60") + warn( + "For now, the model only support a number of features of 60. Set it to 60" + ) num_features = 60 self.num_features = num_features self.separation_idx = separation_idx self.save_hyperparameters("num_features", "embedding_dim", "separation_idx") - # sincnet module sincnet = merge_dict(self.SINCNET_DEFAULTS, sincnet) sincnet["sample_rate"] = sample_rate - self.sincnet =SincNet(**sincnet) + self.sincnet = SincNet(**sincnet) self.save_hyperparameters("sincnet") # tdnn modules @@ -345,7 +412,7 @@ def __init__( out_channels=out_channel, kernel_size=kernel_size, dilation=dilation, - padding="same" + padding="same", ), nn.LeakyReLU(), nn.BatchNorm1d(out_channel), @@ -431,13 +498,15 @@ def build(self): ) self.encoder = nn.LSTM( # number of channel in the outputs of the last TDNN layer + lstm_out_features - input_size= self.last_tdnn_out_channels + lstm_out_features, - hidden_size= len(diarization_spec.classes) * self.last_tdnn_out_channels, + input_size=self.last_tdnn_out_channels + lstm_out_features, + hidden_size=len(diarization_spec.classes) * self.last_tdnn_out_channels, batch_first=True, bidirectional=True, ) - def forward(self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = None + ) -> torch.Tensor: """ Parameters @@ -462,7 +531,9 @@ def forward(self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = Non rearrange(dia_outputs, "batch feature frame -> batch frame feature") ) else: - dia_outputs = rearrange(common_outputs, "batch feature frame -> batch frame feature") + dia_outputs = rearrange( + common_outputs, "batch feature frame -> batch frame feature" + ) for i, lstm in enumerate(self.lstm): dia_outputs, _ = lstm(dia_outputs) if i + 1 < self.hparams.lstm["num_layers"]: @@ -484,8 +555,9 @@ def forward(self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = Non emb_outputs = torch.cat((emb_outputs, lstm_outputs), dim=2) _, emb_outputs = self.encoder(emb_outputs) emb_outputs = rearrange(emb_outputs[0], "l b h -> b (l h)") - emb_outputs = torch.reshape(emb_outputs, - (emb_outputs.shape[0], self.powerset.num_classes, -1)) + emb_outputs = torch.reshape( + emb_outputs, (emb_outputs.shape[0], self.powerset.num_classes, -1) + ) emb_outputs = self.embedding(emb_outputs) return (dia_outputs, emb_outputs) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index f31e1022b..6c2ce0873 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -166,9 +166,6 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): Chunk start time duration : float Chunk duration. - subtask: int - - 0 : diarization task - - 1 : embedding task Returns ------- @@ -207,13 +204,14 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): ] # discretize chunk annotations at model output resolution - start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - # TODO handle tuple outputs from the model - start_idx = np.floor(start / self.model.example_output[0].frames.step).astype( - int - ) - end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - end_idx = np.ceil(end / self.model.example_output[0].frames.step).astype(int) + step = self.model.receptive_field().step + half = 0.5 * self.model.receptive_field().duration + + start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - half + start_idx = np.maximum(0, np.round(start / step)).astype(int) + + end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - half + end_idx = np.round(end / step).astype(int) # get list and number of labels for current scope labels = list(np.unique(chunk_annotations[label_scope_key])) @@ -223,9 +221,10 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): pass # initial frame-level targets - y = np.zeros( - (self.model.example_output[0].num_frames, num_labels), dtype=np.uint8 + num_frames = self.model.num_frames( + round(duration * self.model.hparams.sample_rate) ) + y = np.zeros((num_frames, num_labels), dtype=np.uint8) # map labels to indices mapping = {label: idx for idx, label in enumerate(labels)} @@ -234,11 +233,10 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): start_idx, end_idx, chunk_annotations[label_scope_key] ): mapped_label = mapping[label] - y[start:end, mapped_label] = 1 + y[start : end + 1, mapped_label] = 1 + + sample["y"] = SlidingWindowFeature(y, self.model.receptive_field, labels=labels) - sample["y"] = SlidingWindowFeature( - y, self.model.example_output[0].frames, labels=labels - ) metadata = self.prepared_data["audio-metadata"][file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} sample["meta"]["file"] = file_id From d3326b1aac02993c8d21c3b1b2ccd2c936718d69 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Wed, 15 May 2024 09:29:36 +0200 Subject: [PATCH 71/83] update joint task `training_step` ... and fix some bugs --- .../speaker_diarization_and_embedding.py | 105 +++++++++--------- 1 file changed, 50 insertions(+), 55 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 6c2ce0873..7bea2ca08 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -23,6 +23,7 @@ import itertools import math import random +import warnings from typing import Dict, Literal, Optional, Sequence, Union import numpy as np @@ -139,6 +140,23 @@ def setup(self, stage="fit"): ] ) + # if there is no file dedicated to the embedding task + if self.alpha != 1.0 and len(embedding_classes) == 0: + self.num_dia_samples = self.batch_size + self.alpha = 1.0 + warnings.warn( + "No class found for the speaker embedding task. Model will be trained on the speaker diarization task only." + ) + + if self.alpha != 0.0 and np.sum(global_scope_mask) == len( + self.prepared_data["annotations-segments"] + ): + self.num_dia_samples = 0 + self.alpha = 0.0 + warnings.warn( + "No segment found for the speaker diarization task. Model will be trained on the speaker embedding task only." + ) + speaker_diarization = Specifications( duration=self.duration, resolution=Resolution.FRAME, @@ -363,33 +381,24 @@ def train__iter__helper(self, rng: random.Random, **filters): # get the subset of embedding database files from training files embedding_files_ids = file_ids[np.isin(file_ids, self.embedding_files_id)] - annotated_duration = self.prepared_data["audio-annotated"][file_ids] - # set duration of files for the embedding part to zero, in order to not - # drawn them for diarization part - annotated_duration[embedding_files_ids] = 0.0 + if self.num_dia_samples > 0: + annotated_duration = self.prepared_data["audio-annotated"][file_ids] + # set duration of files for the embedding part to zero, in order to not + # drawn them for diarization part + annotated_duration[embedding_files_ids] = 0.0 - # test if there is at least one file for the diarization subtask - # to prevent probabilities from summing to zero - if np.any(annotated_duration != 0.0): cum_prob_annotated_duration = np.cumsum( annotated_duration / np.sum(annotated_duration) ) - else: - # There is only files for the embedding subtask, so only train on - # this task - self.num_dia_samples = 0.0 - self.alpha = 0.0 duration = self.duration + batch_size = self.batch_size - # use original order for the first run of the shuffled classes list: - shuffled_embedding_classes = list( - self.specifications[Subtasks.index("embedding")].classes - ) + # use original order for the first run on the shuffled classes list: + emb_task_classes = self.specifications[Subtasks.index("embedding")].classes[:] sample_idx = 0 embedding_class_idx = 0 - while True: if sample_idx < self.num_dia_samples: file_id, start_time = self.draw_diarization_chunk( @@ -398,15 +407,16 @@ def train__iter__helper(self, rng: random.Random, **filters): else: # shuffle embedding classes list and go through this shuffled list # to make sure to see all the speakers during training - if embedding_class_idx == len(shuffled_embedding_classes): - rng.shuffle(shuffled_embedding_classes) + if embedding_class_idx == len(emb_task_classes): + rng.shuffle(emb_task_classes) embedding_class_idx = 0 - klass = shuffled_embedding_classes[embedding_class_idx] + klass = emb_task_classes[embedding_class_idx] embedding_class_idx += 1 file_id, start_time = self.draw_embedding_chunk(klass, duration) sample = self.prepare_chunk(file_id, start_time, duration) - sample_idx = (sample_idx + 1) % self.batch_size + sample_idx = (sample_idx + 1) % batch_size + yield sample def train__iter__(self): @@ -599,15 +609,12 @@ def segmentation_loss( return seg_loss - def compute_diarization_loss(self, dia_chunks, dia_prediction, permutated_target): + def compute_diarization_loss(self, prediction, permutated_target): """Compute loss for the speaker diarization subtask Parameters ---------- - dia_chunks : torch.Tensor - tensor specifying the chunks assigned to the speaker diarization - task in the current batch. Shape of (batch_size,) - dia_prediction : torch.Tensor + prediction : torch.Tensor speaker diarization output predicted by the model for the current batch. Shape of (batch_size, num_spk, num_frames) permutated_target: torch.Tensor @@ -619,12 +626,8 @@ def compute_diarization_loss(self, dia_chunks, dia_prediction, permutated_target Permutation-invariant diarization loss """ - # Get chunks corresponding to the diarization subtask - chunks_prediction = dia_prediction[dia_chunks] - # Get the permutated reference corresponding to diarization subtask - permutated_target_dia = permutated_target[dia_chunks] # Compute segmentation loss - dia_loss = self.segmentation_loss(chunks_prediction, permutated_target_dia) + dia_loss = self.segmentation_loss(prediction, permutated_target) self.model.log( "loss/train/dia", dia_loss, @@ -635,14 +638,11 @@ def compute_diarization_loss(self, dia_chunks, dia_prediction, permutated_target ) return dia_loss - def compute_embedding_loss(self, emb_chunks, emb_prediction, target_emb): + def compute_embedding_loss(self, emb_prediction, target_emb): """Compute loss for the speaker embeddings extraction subtask Parameters ---------- - emb_chunks : torch.Tensor - tensor specifying the chunks assigned to the speaker embeddings extraction - task in the current batch. Shape of (batch_size,) emb_prediction : torch.Tensor speaker embeddings predicted by the model for the current batch. Shape of (batch_size * num_spk, embedding_dim) @@ -656,9 +656,9 @@ def compute_embedding_loss(self, emb_chunks, emb_prediction, target_emb): """ # Get speaker representations from the embedding subtask - embeddings = rearrange(emb_prediction[emb_chunks], "b s e -> (b s) e") + embeddings = rearrange(emb_prediction, "b s e -> (b s) e") # Get corresponding target label - targets = rearrange(target_emb[emb_chunks], "b s -> (b s)") + targets = rearrange(target_emb, "b s -> (b s)") # compute loss only on global scope speaker embedding valid_emb = targets != -1 @@ -707,6 +707,8 @@ def training_step(self, batch, batch_idx: int): # drop samples that contain too many speakers num_speakers: torch.Tensor = torch.sum(torch.any(target_dia, dim=1), dim=1) keep: torch.Tensor = num_speakers <= self.max_speakers_per_chunk + + num_remaining_dia_samples = torch.sum(keep[: self.num_dia_samples]) target_dia = target_dia[keep] target_emb = target_emb[keep] waveform = waveform[keep] @@ -726,36 +728,29 @@ def training_step(self, batch, batch_idx: int): torch.arange(target_emb.shape[0]).unsqueeze(1), permut_map ] - # filter out the speaker in the reference that were not found by the diarization - # part of the model, to not compute the embedding loss on these speaker: - # active_spk_mask = torch.any(rearrange(dia_multilabel, "b f s -> b s f"), dim=2) - # (batch_size, num_spk) - # emb_prediction = emb_prediction[active_spk_mask] - # (num_active_spk_found_in_all_the_chunks, emb_size) - # permutated_target_emb = permutated_target_emb[permutated_target_emb != 1] - # (num_activate_spk_found,) - permutated_target_powerset = self.model.powerset.to_powerset( permutated_target_dia.float() ) - # get embedding and diarization chunks position in current batch - emb_chunks = batch["meta"]["scope"] == 2 # global scope for embedding task - dia_chunks = ( - batch["meta"]["scope"] < 2 - ) # file and database scope for diarization task + + dia_prediction = dia_prediction[:num_remaining_dia_samples] + permutated_target_powerset = permutated_target_powerset[ + :num_remaining_dia_samples + ] dia_loss = torch.tensor(0) # if batch contains diarization subtask chunks, then compute diarization loss on these chunks: - if dia_chunks.any(): + if self.alpha != 0.0 and torch.any(keep[: self.num_dia_samples]): dia_loss = self.compute_diarization_loss( - dia_chunks, dia_prediction, permutated_target_powerset + dia_prediction, permutated_target_powerset ) emb_loss = torch.tensor(0) # if batch contains embedding subtask chunks, then compute embedding loss on these chunks: - if emb_chunks.any(): + if self.alpha != 1.0 and torch.any(keep[self.num_dia_samples :]): + emb_prediction = emb_prediction[num_remaining_dia_samples:] + permutated_target_emb = permutated_target_emb[num_remaining_dia_samples:] emb_loss = self.compute_embedding_loss( - emb_chunks, emb_prediction, permutated_target_emb + emb_prediction, permutated_target_emb ) loss = alpha * dia_loss + (1 - alpha) * emb_loss From a36420d59472ac5e3f582e2b792df2526406905e Mon Sep 17 00:00:00 2001 From: clement-pages Date: Mon, 27 May 2024 11:29:36 +0200 Subject: [PATCH 72/83] fix(task): fiw wrong call to `receptive_field` in `prepare_chunk` --- .../tasks/joint_task/speaker_diarization_and_embedding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 7bea2ca08..872042d52 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -222,8 +222,8 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): ] # discretize chunk annotations at model output resolution - step = self.model.receptive_field().step - half = 0.5 * self.model.receptive_field().duration + step = self.model.receptive_field.step + half = 0.5 * self.model.receptive_field.duration start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - half start_idx = np.maximum(0, np.round(start / step)).astype(int) From 62fad7809bb341396db6c5d3692323f76b9cf767 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 28 May 2024 11:24:44 +0200 Subject: [PATCH 73/83] update(joint task): filter out inactive speaker embeddings from loss computation --- .../speaker_diarization_and_embedding.py | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 872042d52..b0db42371 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -638,7 +638,7 @@ def compute_diarization_loss(self, prediction, permutated_target): ) return dia_loss - def compute_embedding_loss(self, emb_prediction, target_emb): + def compute_embedding_loss(self, emb_prediction, target_emb, valid_embs): """Compute loss for the speaker embeddings extraction subtask Parameters @@ -660,11 +660,10 @@ def compute_embedding_loss(self, emb_prediction, target_emb): # Get corresponding target label targets = rearrange(target_emb, "b s -> (b s)") # compute loss only on global scope speaker embedding - valid_emb = targets != -1 - + valid_embs = rearrange(valid_embs, "b s -> (b s)") # compute the loss emb_loss = self.model.arc_face_loss( - embeddings[valid_emb, :], targets[valid_emb] + embeddings[valid_embs, :], targets[valid_embs] ) # skip batch if something went wrong for some reason @@ -696,7 +695,7 @@ def training_step(self, batch, batch_idx: int): loss : {str: torch.tensor} {"loss": loss} """ - alpha = self.alpha + # batch waveforms (batch_size, num_channels, num_samples) waveform = batch["X"] # batch diarization references (batch_size, num_channels, num_speakers) @@ -705,14 +704,15 @@ def training_step(self, batch, batch_idx: int): target_emb = batch["y_emb"] # drop samples that contain too many speakers - num_speakers: torch.Tensor = torch.sum(torch.any(target_dia, dim=1), dim=1) - keep: torch.Tensor = num_speakers <= self.max_speakers_per_chunk + num_speakers = torch.sum(torch.any(target_dia, dim=1), dim=1) + keep = num_speakers <= self.max_speakers_per_chunk - num_remaining_dia_samples = torch.sum(keep[: self.num_dia_samples]) target_dia = target_dia[keep] target_emb = target_emb[keep] waveform = waveform[keep] + num_remaining_dia_samples = torch.sum(keep[: self.num_dia_samples]) + # corner case if not keep.any(): return None @@ -728,6 +728,11 @@ def training_step(self, batch, batch_idx: int): torch.arange(target_emb.shape[0]).unsqueeze(1), permut_map ] + # an embedding is valid only if corresponding speaker is active in the diarization prediction and reference + active_speaker_pred = torch.any(dia_multilabel > 0, dim=1) + active_speaker_ref = torch.any(permutated_target_dia == 1, dim=1) + valid_embs = torch.logical_and(active_speaker_pred, active_speaker_ref)[num_remaining_dia_samples:] + permutated_target_powerset = self.model.powerset.to_powerset( permutated_target_dia.float() ) @@ -750,10 +755,10 @@ def training_step(self, batch, batch_idx: int): emb_prediction = emb_prediction[num_remaining_dia_samples:] permutated_target_emb = permutated_target_emb[num_remaining_dia_samples:] emb_loss = self.compute_embedding_loss( - emb_prediction, permutated_target_emb + emb_prediction, permutated_target_emb, valid_embs ) - loss = alpha * dia_loss + (1 - alpha) * emb_loss + loss = self.alpha * dia_loss + (1 - self.alpha) * emb_loss return {"loss": loss} # TODO: no need to compute gradient in this method From 8349818378b7d45a1b8bc8e729a0057597a2e1e4 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Fri, 21 Jun 2024 14:54:59 +0200 Subject: [PATCH 74/83] allow to only compute mean or std in `StatsPool` --- pyannote/audio/models/blocks/pooling.py | 45 ++++++++++++++++++++----- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/pyannote/audio/models/blocks/pooling.py b/pyannote/audio/models/blocks/pooling.py index dc31bea8e..17c2f9030 100644 --- a/pyannote/audio/models/blocks/pooling.py +++ b/pyannote/audio/models/blocks/pooling.py @@ -28,7 +28,9 @@ import torch.nn.functional as F -def _pool(sequences: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: +def _pool( + sequences: torch.Tensor, weights: torch.Tensor, compute_mean: bool, compute_std:bool + ) -> torch.Tensor: """Helper function to compute statistics pooling Assumes that weights are already interpolated to match the number of frames @@ -50,16 +52,24 @@ def _pool(sequences: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: weights = weights.unsqueeze(dim=1) # (batch, 1, frames) + stats = [] + v1 = weights.sum(dim=2) + 1e-8 mean = torch.sum(sequences * weights, dim=2) / v1 - dx2 = torch.square(sequences - mean.unsqueeze(2)) - v2 = torch.square(weights).sum(dim=2) + if compute_mean: + stats.append(mean) + + if compute_std: + dx2 = torch.square(sequences - mean.unsqueeze(2)) + v2 = torch.square(weights).sum(dim=2) - var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1 + 1e-8) - std = torch.sqrt(var) + var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1 + 1e-8) + std = torch.sqrt(var) - return torch.cat([mean, std], dim=1) + stats.append(std) + + return torch.cat(stats, dim=1) class StatsPool(nn.Module): @@ -68,14 +78,33 @@ class StatsPool(nn.Module): Compute temporal mean and (unbiased) standard deviation and returns their concatenation. + Parameters + ---------- + + compute_mean: bool, optional + whether to compute (and return) temporal mean. + Default to True + compute_std: bool, optional + whether to compute (and return) temporal standard deviation. + Default to True + Reference --------- https://en.wikipedia.org/wiki/Weighted_arithmetic_mean """ + def __init__( + self, + compute_mean: Optional[bool] = True, + computde_std: Optional[bool] = True, + ): + super().__init__() + self.compute_mean = compute_mean + self.compute_std = computde_std + def forward( - self, sequences: torch.Tensor, weights: Optional[torch.Tensor] = None + self, sequences: torch.Tensor, weights: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass @@ -122,7 +151,7 @@ def forward( output = torch.stack( [ - _pool(sequences, weights[:, speaker, :]) + _pool(sequences, weights[:, speaker, :], self.compute_mean, self.compute_std) for speaker in range(num_speakers) ], dim=1, From 0858227aefcc9da53340e8ab9184662f03afb8df Mon Sep 17 00:00:00 2001 From: clement-pages Date: Fri, 21 Jun 2024 14:56:10 +0200 Subject: [PATCH 75/83] update diarization + embeddings joint task --- .../speaker_diarization_and_embedding.py | 321 +++++++++++++++++- 1 file changed, 316 insertions(+), 5 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index b0db42371..76833c997 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -20,10 +20,13 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +from collections import defaultdict import itertools import math +from pathlib import Path import random import warnings +from tempfile import mkstemp from typing import Dict, Literal, Optional, Sequence, Union import numpy as np @@ -37,7 +40,9 @@ from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric -from pyannote.audio.core.task import Problem, Resolution, Specifications +from pyannote.audio.core.task import ( + Problem, Resolution, Specifications, get_dtype +) from pyannote.audio.tasks import SpeakerDiarization from pyannote.audio.torchmetrics import ( DiarizationErrorRate, @@ -117,6 +122,305 @@ def __init__( # * diarization databases are those with file or database speaker label scope self.embedding_files_id = [] + def prepare_data(self): + """Use this to prepare data from task protocol + + Notes + ----- + Called only once on the main process (and only on it), for global_rank 0. + + After this method is called, the task should have a `prepared_data` attribute + with the following dictionary structure: + + prepared_data = { + 'protocol': name of the protocol + 'audio-path': array of N paths to audio + 'audio-metadata': array of N audio infos such as audio subset, scope and database + 'audio-info': array of N audio torchaudio.info struct + 'audio-encoding': array of N audio encodings + 'audio-annotated': array of N annotated duration (usually equals file duration but might be shorter if file is not fully annotated) + 'annotations-regions': array of M annotated regions + 'annotations-segments': array of M' annotated segments + 'metadata-values': dict of lists of values for subset, scope and database + 'metadata-`database-name`-labels': array of `database-name` labels. Each database with "database" scope labels has it own array. + 'metadata-labels': array of global scope labels + } + + """ + + if self.cache: + # check if cache exists and is not empty: + if self.cache.exists() and self.cache.stat().st_size > 0: + # data was already created, nothing to do + return + # create parent directory if needed + self.cache.parent.mkdir(parents=True, exist_ok=True) + else: + # if no cache was provided by user, create a temporary file + # in system directory used for temp files + self.cache = Path(mkstemp()[1]) + + # list of possible values for each metadata key + # (will become .prepared_data[""]) + metadata_unique_values = defaultdict(list) + metadata_unique_values["subset"] = Subsets + metadata_unique_values["scope"] = Scopes + + audios = list() # list of path to audio files + audio_infos = list() + audio_encodings = list() + metadata = list() # list of metadata + + annotated_duration = list() # total duration of annotated regions (per file) + annotated_regions = list() # annotated regions + annotations = list() # actual annotations + unique_labels = list() + database_unique_labels = {} + + if self.has_validation: + files_iter = itertools.chain( + zip(itertools.repeat("train"), self.protocol.train()), + zip(itertools.repeat("development"), self.protocol.development()), + ) + else: + files_iter = zip(itertools.repeat("train"), self.protocol.train()) + + for file_id, (subset, file) in enumerate(files_iter): + # gather metadata and update metadata_unique_values so that each metadatum + # (e.g. source database or label) is represented by an integer. + metadatum = dict() + + # keep track of source database and subset (train, development, or test) + if file["database"] not in metadata_unique_values["database"]: + metadata_unique_values["database"].append(file["database"]) + metadatum["database"] = metadata_unique_values["database"].index( + file["database"] + ) + + metadatum["subset"] = Subsets.index(subset) + + # keep track of label scope (file, database, or global) + metadatum["scope"] = Scopes.index(file["scope"]) + + remaining_metadata_keys = set(file) - set( + [ + "uri", + "database", + "subset", + "audio", + "torchaudio.info", + "scope", + "classes", + "annotation", + "annotated", + ] + ) + + # keep track of any other (integer or string) metadata provided by the protocol + # (e.g. a "domain" key for domain-adversarial training) + for key in remaining_metadata_keys: + value = file[key] + + if isinstance(value, str): + if value not in metadata_unique_values[key]: + metadata_unique_values[key].append(value) + metadatum[key] = metadata_unique_values[key].index(value) + + elif isinstance(value, int): + metadatum[key] = value + + else: + warnings.warn( + f"Ignoring '{key}' metadata because of its type ({type(value)}). Only str and int are supported for now.", + category=UserWarning, + ) + + metadata.append(metadatum) + + # reset list of file-scoped labels + file_unique_labels = list() + + # path to audio file + audios.append(str(file["audio"])) + + # audio info + audio_info = file["torchaudio.info"] + audio_infos.append( + ( + audio_info.sample_rate, # sample rate + audio_info.num_frames, # number of frames + audio_info.num_channels, # number of channels + audio_info.bits_per_sample, # bits per sample + ) + ) + audio_encodings.append(audio_info.encoding) # encoding + + # annotated regions and duration + _annotated_duration = 0.0 + for segment in file["annotated"]: + # skip annotated regions that are shorter than training chunk duration + # if segment.duration < self.duration: + # continue + + # append annotated region + annotated_region = ( + file_id, + segment.duration, + segment.start, + ) + annotated_regions.append(annotated_region) + + # increment annotated duration + _annotated_duration += segment.duration + + # append annotated duration + annotated_duration.append(_annotated_duration) + + # annotations + for segment, _, label in file["annotation"].itertracks(yield_label=True): + # "scope" is provided by speaker diarization protocols to indicate + # whether speaker labels are local to the file ('file'), consistent across + # all files in a database ('database'), or globally consistent ('global') + + # 0 = 'file' / 1 = 'database' / 2 = 'global' + scope = Scopes.index(file["scope"]) + + # update list of file-scope labels + if label not in file_unique_labels: + file_unique_labels.append(label) + # and convert label to its (file-scope) index + file_label_idx = file_unique_labels.index(label) + + database_label_idx = global_label_idx = -1 + + if scope > 0: # 'database' or 'global' + # update list of database-scope labels + database = file["database"] + if database not in database_unique_labels: + database_unique_labels[database] = [] + if label not in database_unique_labels[database]: + database_unique_labels[database].append(label) + + # and convert label to its (database-scope) index + database_label_idx = database_unique_labels[database].index(label) + + if scope > 1: # 'global' + # update list of global-scope labels + if label not in unique_labels: + unique_labels.append(label) + # and convert label to its (global-scope) index + global_label_idx = unique_labels.index(label) + + annotations.append( + ( + file_id, # index of file + segment.start, # start time + segment.end, # end time + file_label_idx, # file-scope label index + database_label_idx, # database-scope label index + global_label_idx, # global-scope index + ) + ) + + # since not all metadata keys are present in all files, fallback to -1 when a key is missing + metadata = [ + tuple(metadatum.get(key, -1) for key in metadata_unique_values) + for metadatum in metadata + ] + metadata_dtype = [ + (key, get_dtype(max(m[i] for m in metadata))) + for i, key in enumerate(metadata_unique_values) + ] + + # turn list of files metadata into a single numpy array + # TODO: improve using https://github.com/pytorch/pytorch/issues/13246#issuecomment-617140519 + info_dtype = [ + ( + "sample_rate", + get_dtype(max(ai[0] for ai in audio_infos)), + ), + ( + "num_frames", + get_dtype(max(ai[1] for ai in audio_infos)), + ), + ("num_channels", "B"), + ("bits_per_sample", "B"), + ] + + # turn list of annotated regions into a single numpy array + region_dtype = [ + ( + "file_id", + get_dtype(max(ar[0] for ar in annotated_regions)), + ), + ("duration", "f"), + ("start", "f"), + ] + + # turn list of annotations into a single numpy array + segment_dtype = [ + ( + "file_id", + get_dtype(max(a[0] for a in annotations)), + ), + ("start", "f"), + ("end", "f"), + ("file_label_idx", get_dtype(max(a[3] for a in annotations))), + ("database_label_idx", get_dtype(max(a[4] for a in annotations))), + ("global_label_idx", get_dtype(max(a[5] for a in annotations))), + ] + + # save all protocol data in a dict + prepared_data = {} + + # keep track of protocol name + prepared_data["protocol"] = self.protocol.name + + prepared_data["audio-path"] = np.array(audios, dtype=np.str_) + audios.clear() + + prepared_data["audio-metadata"] = np.array(metadata, dtype=metadata_dtype) + metadata.clear() + + prepared_data["audio-info"] = np.array(audio_infos, dtype=info_dtype) + audio_infos.clear() + + prepared_data["audio-encoding"] = np.array(audio_encodings, dtype=np.str_) + audio_encodings.clear() + + prepared_data["audio-annotated"] = np.array(annotated_duration) + annotated_duration.clear() + + prepared_data["annotations-regions"] = np.array( + annotated_regions, dtype=region_dtype + ) + annotated_regions.clear() + + prepared_data["annotations-segments"] = np.array( + annotations, dtype=segment_dtype + ) + annotations.clear() + + prepared_data["metadata-values"] = metadata_unique_values + + for database, labels in database_unique_labels.items(): + prepared_data[f"metadata-{database}-labels"] = np.array( + labels, dtype=np.str_ + ) + database_unique_labels.clear() + + prepared_data["metadata-labels"] = np.array(unique_labels, dtype=np.str_) + unique_labels.clear() + + if self.has_validation: + self.prepare_validation(prepared_data) + + self.post_prepare_data(prepared_data) + + # save prepared data on the disk + with open(self.cache, "wb") as cache_file: + np.savez_compressed(cache_file, **prepared_data) + def setup(self, stage="fit"): """Setup method @@ -251,7 +555,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): start_idx, end_idx, chunk_annotations[label_scope_key] ): mapped_label = mapping[label] - y[start : end + 1, mapped_label] = 1 + y[start:end + 1, mapped_label] = 1 sample["y"] = SlidingWindowFeature(y, self.model.receptive_field, labels=labels) @@ -666,6 +970,9 @@ def compute_embedding_loss(self, emb_prediction, target_emb, valid_embs): embeddings[valid_embs, :], targets[valid_embs] ) + if torch.any(valid_embs): + emb_loss = (1. / torch.sum(valid_embs)) * emb_loss + # skip batch if something went wrong for some reason if torch.isnan(emb_loss): return None @@ -731,7 +1038,9 @@ def training_step(self, batch, batch_idx: int): # an embedding is valid only if corresponding speaker is active in the diarization prediction and reference active_speaker_pred = torch.any(dia_multilabel > 0, dim=1) active_speaker_ref = torch.any(permutated_target_dia == 1, dim=1) - valid_embs = torch.logical_and(active_speaker_pred, active_speaker_ref)[num_remaining_dia_samples:] + valid_embs = torch.logical_and(active_speaker_pred, active_speaker_ref)[ + num_remaining_dia_samples: + ] permutated_target_powerset = self.model.powerset.to_powerset( permutated_target_dia.float() @@ -751,14 +1060,16 @@ def training_step(self, batch, batch_idx: int): emb_loss = torch.tensor(0) # if batch contains embedding subtask chunks, then compute embedding loss on these chunks: - if self.alpha != 1.0 and torch.any(keep[self.num_dia_samples :]): + if self.alpha != 1.0 and torch.any(valid_embs): emb_prediction = emb_prediction[num_remaining_dia_samples:] permutated_target_emb = permutated_target_emb[num_remaining_dia_samples:] emb_loss = self.compute_embedding_loss( emb_prediction, permutated_target_emb, valid_embs ) + loss = self.alpha * dia_loss + (1 - self.alpha) * emb_loss + else: + loss = self.alpha * dia_loss - loss = self.alpha * dia_loss + (1 - self.alpha) * emb_loss return {"loss": loss} # TODO: no need to compute gradient in this method From ad9e435ef0deb498e107ba2d3b90280ea24e4d1f Mon Sep 17 00:00:00 2001 From: clement-pages Date: Fri, 21 Jun 2024 14:57:47 +0200 Subject: [PATCH 76/83] wip: update joint model --- pyannote/audio/models/joint/__init__.py | 6 +- .../models/joint/end_to_end_diarization.py | 997 +++++++++++++----- 2 files changed, 708 insertions(+), 295 deletions(-) diff --git a/pyannote/audio/models/joint/__init__.py b/pyannote/audio/models/joint/__init__.py index f32ef8d98..97c1481d8 100644 --- a/pyannote/audio/models/joint/__init__.py +++ b/pyannote/audio/models/joint/__init__.py @@ -20,6 +20,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from .end_to_end_diarization import SpeakerEndToEndDiarization, SpeakerEndToEndDiarizationV2 +from .end_to_end_diarization import ( + WavLMEnd2EndDiarization, WavLMEnd2EndDiarizationv2, WavLMEnd2EndDiarizationv3 +) -__all__ = ["SpeakerEndToEndDiarization", "SpeakerEndToEndDiarizationV2"] +__all__ = ["WavLMEnd2EndDiarization", "WavLMEnd2EndDiarizationv2", "WavLMEnd2EndDiarizationv3"] diff --git a/pyannote/audio/models/joint/end_to_end_diarization.py b/pyannote/audio/models/joint/end_to_end_diarization.py index bfc0bec65..1be9abd05 100644 --- a/pyannote/audio/models/joint/end_to_end_diarization.py +++ b/pyannote/audio/models/joint/end_to_end_diarization.py @@ -9,8 +9,8 @@ # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, @@ -20,10 +20,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import os from functools import lru_cache from typing import Literal, Optional, Union -from warnings import warn import torch import torch.nn.functional as F @@ -34,79 +32,198 @@ from pyannote.audio.core.model import Model from pyannote.audio.core.task import Task from pyannote.audio.models.blocks.pooling import StatsPool -from pyannote.audio.models.blocks.sincnet import SincNet from pyannote.audio.utils.params import merge_dict from pyannote.audio.utils.powerset import Powerset +from pyannote.audio.utils.receptive_field import ( + conv1d_num_frames, + conv1d_receptive_field_center, + conv1d_receptive_field_size, +) + +import torchaudio + # TODO deplace these two lines into uitls/multi_task Subtask = Literal["diarization", "embedding"] Subtasks = list(Subtask.__args__) -class WeSpeakerBasesEndToEndDiarization(Model): - """ - WeSpeaker-based joint speaker diarization and speaker - embedding extraction model +class WavLMEnd2EndDiarization(Model): + """Self-Supervised representation for joint speaker diarization + and speaker embeddings extraction + + + Parameters + ---------- + sample_rate : int, optional + Audio sample rate. Defaults to 16kHz (16000). + num_channels : int, optional + Number of channels. Defaults to mono (1). + wav2vec: dict or str, optional + Defaults to "WAVLM_BASE". + wav2vec_layer: int, optional + Index of layer to use as input to the LSTM. + Defaults (-1) to use average of all layers (with learnable weights). + freeze_wav2vec: bool, optional + Whether to freeze wa2vec. Default to true + emb_dim: int, optional + Dimension of the speaker embedding in output """ + WAV2VEC_DEFAULTS = "WAVLM_BASE" + + LSTM_DEFAULTS = { + "hidden_size": 128, + "num_layers": 4, + "bidirectional": True, + "monolithic": True, + "dropout": 0.0, + } + + LINEAR_DEFAULT = {"hidden_size": 128, "num_layers": 2} + def __init__( self, - sincnet: dict = None, - lstm: dict = None, - linear: dict = None, - sample_rate=16000, - embedding_dim=256, - num_channels=1, - task: Optional[Union[Task, None]] = None, + sample_rate: int = 16000, + num_channels: int = 1, + wav2vec: Union[dict, str] = None, + wav2vec_layer: int = -1, + freeze_wav2vec: bool = True, + lstm: Optional[dict] = None, + linear: Optional[dict] = None, + embedding_dim: Optional[int] = 192, + task: Optional[Task] = None, ): super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) - # speakers embedding extraction submodel: - self.resnet34 = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM") - self.embedding_dim = embedding_dim - self.save_hyperparameters("embedding_dim") + if isinstance(wav2vec, str): + # `wav2vec` is one of the supported pipelines from torchaudio (e.g. "WAVLM_BASE") + if hasattr(torchaudio.pipelines, wav2vec): + bundle = getattr(torchaudio.pipelines, wav2vec) + if sample_rate != bundle.sample_rate: + raise ValueError( + f"Expected {bundle.sample_rate}Hz, found {sample_rate}Hz." + ) + wav2vec_dim = bundle._params["encoder_embed_dim"] + wav2vec_num_layers = bundle._params["encoder_num_layers"] + self.wav2vec = bundle.get_model() + + # `wav2vec` is a path to a self-supervised representation checkpoint + else: + _checkpoint = torch.load(wav2vec) + wav2vec = _checkpoint.pop("config") + self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec) + state_dict = _checkpoint.pop("state_dict") + self.wav2vec.load_state_dict(state_dict) + wav2vec_dim = wav2vec["encoder_embed_dim"] + wav2vec_num_layers = wav2vec["encoder_num_layers"] + + # `wav2vec` is a config dictionary understood by `wav2vec2_model` + # this branch is typically used by Model.from_pretrained(...) + elif isinstance(wav2vec, dict): + self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec) + wav2vec_dim = wav2vec["encoder_embed_dim"] + wav2vec_num_layers = wav2vec["encoder_num_layers"] + + if wav2vec_layer < 0: + # weighting parameters for the diarization branch + self.dia_wav2vec_weights = nn.Parameter( + data=torch.ones(wav2vec_num_layers), requires_grad=True + ) + # weighting parameters for the embedding branch + self.emb_wav2vec_weights = nn.Parameter( + data=torch.ones(wav2vec_num_layers), requires_grad=True + ) + self.save_hyperparameters("wav2vec", "wav2vec_layer", "freeze_wav2vec") - # speaker segmentation submodel: - self.pyannet = Model.from_pretrained( - "pyannote/segmentation-3.0", - use_auth_token=os.environ["HG_TOKEN"], + lstm = merge_dict(self.LSTM_DEFAULTS, lstm) + lstm["batch_first"] = True + linear = merge_dict(self.LINEAR_DEFAULT, linear) + self.save_hyperparameters("lstm", "linear") + monolithic = lstm["monolithic"] + if monolithic: + multi_layer_lstm = dict(lstm) + del multi_layer_lstm["monolithic"] + self.lstm = nn.LSTM(wav2vec_dim, **multi_layer_lstm) + else: + num_layers = lstm["num_layers"] + if num_layers > 1: + self.dropout = nn.Dropout(p=lstm["dropout"]) + + one_layer_lstm = dict(lstm) + one_layer_lstm["num_layers"] = 1 + one_layer_lstm["dropout"] = 0.0 + del one_layer_lstm["monolithic"] + + self.lstm = nn.ModuleList( + [ + nn.LSTM( + ( + wav2vec_dim + if i == 0 + else lstm["hidden_size"] + * (2 if lstm["bidirectional"] else 1) + ), + **one_layer_lstm, + ) + for i in range(num_layers) + ] + ) + + if linear["num_layers"] < 1: + return + lstm_out_features = self.hparams.lstm["hidden_size"] * ( + 2 if self.hparams.lstm["bidirectional"] else 1 ) + self.linear = nn.ModuleList( + [ + nn.Linear(in_features, out_features) + for in_features, out_features in pairwise( + [ + lstm_out_features, + ] + + [self.hparams.linear["hidden_size"]] + * self.hparams.linear["num_layers"] + ) + ] + ) + + self.pooling = StatsPool(computde_std=False) + self.embeddings = nn.Linear(wav2vec_dim, embedding_dim) + + self.save_hyperparameters("embedding_dim") @property def dimension(self) -> int: """Dimension of output""" - if isinstance(self.specifications, tuple): - raise ValueError("PyanNet does not support multi-tasking.") - - if self.specifications.powerset: - return self.specifications.num_powerset_classes - else: - return len(self.specifications.classes) - - def build(self): - """""" - dia_specs = self.specifications[Subtasks.index("diarization")] - self.powerset = Powerset( - len(dia_specs.classes), - dia_specs.powerset_max_classes, - ) + return self.specifications[Subtasks.index("diarization")].num_powerset_classes @lru_cache def num_frames(self, num_samples: int) -> int: - """Compute number of output frames for a given number of input samples + """Compute number of output frames Parameters ---------- num_samples : int - Number of input samples + Number of input samples. Returns ------- num_frames : int - Number of output frames + Number of output frames. """ - return self.pyannet.sincnet.num_frames(num_samples) + num_frames = num_samples + for conv_layer in self.wav2vec.feature_extractor.conv_layers: + num_frames = conv1d_num_frames( + num_frames, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + padding=conv_layer.conv.padding[0], + dilation=conv_layer.conv.dilation[0], + ) + + return num_frames def receptive_field_size(self, num_frames: int = 1) -> int: """Compute size of receptive field @@ -121,7 +238,16 @@ def receptive_field_size(self, num_frames: int = 1) -> int: receptive_field_size : int Receptive field size. """ - return self.pyannet.sincnet.receptive_field_size(num_frames=num_frames) + + receptive_field_size = num_frames + for conv_layer in reversed(self.wav2vec.feature_extractor.conv_layers): + receptive_field_size = conv1d_receptive_field_size( + num_frames=receptive_field_size, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + dilation=conv_layer.conv.dilation[0], + ) + return receptive_field_size def receptive_field_center(self, frame: int = 0) -> int: """Compute center of receptive field @@ -136,134 +262,218 @@ def receptive_field_center(self, frame: int = 0) -> int: receptive_field_center : int Index of receptive field center. """ + receptive_field_center = frame + for conv_layer in reversed(self.wav2vec.feature_extractor.conv_layers): + receptive_field_center = conv1d_receptive_field_center( + receptive_field_center, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + padding=conv_layer.conv.padding[0], + dilation=conv_layer.conv.dilation[0], + ) + return receptive_field_center - return self.pyannet.sincnet.receptive_field_center(frame=frame) + def build(self): + """""" + max_num_speaker_per_chunk = len(self.specifications[Subtasks.index("diarization")].classes) + max_num_speaker_per_frame = self.specifications[Subtasks.index("diarization")].powerset_max_classes + self.powerset = Powerset( + max_num_speaker_per_chunk, + max_num_speaker_per_frame + ) - def forward(self, waveformms: torch.Tensor) -> torch.Tensor: - """ + if self.hparams.linear["num_layers"] > 0: + in_features = self.hparams.linear["hidden_size"] + else: + lstm = self.hparams.lstm + in_features = lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1) + + self.classifier = nn.Linear(in_features, self.dimension) + + def forward(self, waveforms: torch.Tensor) -> torch.Tensor: + """Pass forward Parameters ---------- - waveforms : torch.Tensor - Batch of waveforms with shape (batch, channel, sample) + waveforms : (batch, channel, sample) + + Returns + ------- + diarization, embeddings : (batch, frames, classes), (batch, num_speaker, embed_dim) """ - dia_outputs = self.pyannet(waveformms) + + num_layers = ( + None if self.hparams.wav2vec_layer < 0 else self.hparams.wav2vec_layer + ) + + if self.hparams.freeze_wav2vec: + with torch.no_grad(): + outputs, _ = self.wav2vec.extract_features( + waveforms.squeeze(1), num_layers=num_layers + ) + else: + outputs, _ = self.wav2vec.extract_features( + waveforms.squeeze(1), num_layers=num_layers + ) + + if num_layers is None: + dia_outputs = torch.stack(outputs, dim=-1) @ F.softmax( + self.dia_wav2vec_weights, dim=0 + ) + emb_outputs = torch.stack(outputs, dim=-1) @ F.softmax( + self.emb_wav2vec_weights, dim=0 + ) + else: + dia_outputs = emb_outputs = outputs[-1] + + if self.hparams.lstm["monolithic"]: + dia_outputs, _ = self.lstm(dia_outputs) + else: + for i, lstm in enumerate(self.lstm): + dia_outputs, _ = lstm(dia_outputs) + if i + 1 < self.hparams.lstm["num_layers"]: + dia_outputs = self.dropout(dia_outputs) + + if self.hparams.linear["num_layers"] > 0: + for linear in self.linear: + dia_outputs = F.leaky_relu(linear(dia_outputs)) + dia_outputs = self.classifier(dia_outputs) + dia_outputs = F.log_softmax(dia_outputs, dim=-1) + weights = self.powerset.to_multilabel(dia_outputs, soft=True) weights = rearrange(weights, "b f s -> b s f") - emb_outputs = self.resnet34(waveformms, weights) + emb_outputs = rearrange(emb_outputs, "b f w -> b w f") + emb_outputs = self.pooling(emb_outputs, weights) + emb_outputs = self.embeddings(emb_outputs) + return (dia_outputs, emb_outputs) -class SpeakerEndToEndDiarization(Model): - """Speaker End-to-End Diarization and Embedding model - SINCNET -- TDNN .. TDNN -- TDNN ..TDNN -- StatsPool -- Linear -- Classifier - \\ LSTM ... LSTM -- FeedForward -- Classifier +class WavLMEnd2EndDiarizationv2(Model): + """Self-Supervised representation for joint speaker diarization + and speaker embeddings extraction + + + Parameters + ---------- + sample_rate : int, optional + Audio sample rate. Defaults to 16kHz (16000). + num_channels : int, optional + Number of channels. Defaults to mono (1). + wav2vec: dict or str, optional + Defaults to "WAVLM_BASE". + wav2vec_layer: int, optional + Index of layer to use as input to the LSTM. + Defaults (-1) to use average of all layers (with learnable weights). + freeze_wav2vec: bool, optional + Whether to freeze wa2vec. Default to true + emb_dim: int, optional + Dimension of the speaker embedding in output """ - SINCNET_DEFAULTS = {"stride": 10} + WAV2VEC_DEFAULTS = "WAVLM_BASE" + LSTM_DEFAULTS = { "hidden_size": 128, - "num_layers": 2, + "num_layers": 4, "bidirectional": True, "monolithic": True, "dropout": 0.0, - "batch_first": True, } - LINEAR_DEFAULTS = {"hidden_size": 128, "num_layers": 2} + + LINEAR_DEFAULT = {"hidden_size": 128, "num_layers": 2} def __init__( self, - sincnet: dict = None, - lstm: dict = None, - linear: dict = None, sample_rate: int = 16000, num_channels: int = 1, - num_features: int = 60, - embedding_dim: int = 512, - separation_idx: int = 2, + wav2vec: Union[dict, str] = None, + wav2vec_layer: int = -1, + freeze_wav2vec: bool = True, + lstm: Optional[dict] = None, + linear: Optional[dict] = None, + embedding_dim: Optional[int] = 192, task: Optional[Task] = None, ): super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) - if num_features != 60: - warn( - "For now, the model only support a number of features of 60. Set it to 60" + if isinstance(wav2vec, str): + # `wav2vec` is one of the supported pipelines from torchaudio (e.g. "WAVLM_BASE") + if hasattr(torchaudio.pipelines, wav2vec): + bundle = getattr(torchaudio.pipelines, wav2vec) + if sample_rate != bundle.sample_rate: + raise ValueError( + f"Expected {bundle.sample_rate}Hz, found {sample_rate}Hz." + ) + wav2vec_dim = bundle._params["encoder_embed_dim"] + wav2vec_num_layers = bundle._params["encoder_num_layers"] + self.wav2vec = bundle.get_model() + + # `wav2vec` is a path to a self-supervised representation checkpoint + else: + _checkpoint = torch.load(wav2vec) + wav2vec = _checkpoint.pop("config") + self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec) + state_dict = _checkpoint.pop("state_dict") + self.wav2vec.load_state_dict(state_dict) + wav2vec_dim = wav2vec["encoder_embed_dim"] + wav2vec_num_layers = wav2vec["encoder_num_layers"] + + # `wav2vec` is a config dictionary understood by `wav2vec2_model` + # this branch is typically used by Model.from_pretrained(...) + elif isinstance(wav2vec, dict): + self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec) + wav2vec_dim = wav2vec["encoder_embed_dim"] + wav2vec_num_layers = wav2vec["encoder_num_layers"] + + if wav2vec_layer < 0: + # weighting parameters for the diarization branch + self.dia_wav2vec_weights = nn.Parameter( + data=torch.ones(wav2vec_num_layers), requires_grad=True ) - num_features = 60 - self.num_features = num_features - self.separation_idx = separation_idx - self.save_hyperparameters("num_features", "embedding_dim", "separation_idx") - - # sincnet module - sincnet = merge_dict(self.SINCNET_DEFAULTS, sincnet) - sincnet["sample_rate"] = sample_rate - self.sincnet = SincNet(**sincnet) - self.save_hyperparameters("sincnet") - - # tdnn modules - self.tdnn_blocks = nn.ModuleList() - in_channel = num_features - out_channels = [512, 512, 512, 512, 1500] - kernel_sizes = [5, 3, 3, 1, 1] - dilations = [1, 2, 3, 1, 1] - - for out_channel, kernel_size, dilation in zip( - out_channels, kernel_sizes, dilations - ): - self.tdnn_blocks.extend( - [ - nn.Sequential( - nn.Conv1d( - in_channels=in_channel, - out_channels=out_channel, - kernel_size=kernel_size, - dilation=dilation, - padding="same", - ), - nn.LeakyReLU(), - nn.BatchNorm1d(out_channel), - ), - ] + # weighting parameters for the embedding branch + self.emb_wav2vec_weights = nn.Parameter( + data=torch.ones(wav2vec_num_layers), requires_grad=True ) - in_channel = out_channel + self.save_hyperparameters("wav2vec", "wav2vec_layer", "freeze_wav2vec") - # lstm modules: lstm = merge_dict(self.LSTM_DEFAULTS, lstm) - self.save_hyperparameters("lstm") + lstm["batch_first"] = True + linear = merge_dict(self.LINEAR_DEFAULT, linear) + self.save_hyperparameters("lstm", "linear") monolithic = lstm["monolithic"] if monolithic: multi_layer_lstm = dict(lstm) del multi_layer_lstm["monolithic"] - self.lstm = nn.LSTM(out_channels[separation_idx], **multi_layer_lstm) + self.lstm = nn.LSTM(wav2vec_dim, **multi_layer_lstm) else: num_layers = lstm["num_layers"] if num_layers > 1: self.dropout = nn.Dropout(p=lstm["dropout"]) one_layer_lstm = dict(lstm) - del one_layer_lstm["monolithic"] one_layer_lstm["num_layers"] = 1 one_layer_lstm["dropout"] = 0.0 + del one_layer_lstm["monolithic"] self.lstm = nn.ModuleList( [ nn.LSTM( - out_channels[separation_idx] - if i == 0 - else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1), - **one_layer_lstm + ( + wav2vec_dim + if i == 0 + else lstm["hidden_size"] + * (2 if lstm["bidirectional"] else 1) + ), + **one_layer_lstm, ) for i in range(num_layers) ] ) - # linear module for the diarization part: - linear = merge_dict(self.LINEAR_DEFAULTS, linear) - self.save_hyperparameters("linear") if linear["num_layers"] < 1: return - - lstm_out_features: int = self.hparams.lstm["hidden_size"] * ( + lstm_out_features = self.hparams.lstm["hidden_size"] * ( 2 if self.hparams.lstm["bidirectional"] else 1 ) self.linear = nn.ModuleList( @@ -279,185 +489,294 @@ def __init__( ] ) - # stats pooling module for the embedding part: - self.stats_pool = StatsPool() - # linear module for the embedding part: - self.embedding = nn.Linear(in_channel * 2, embedding_dim) + self.pooling = StatsPool(computde_std=False) + self.embeddings = nn.Sequential( + nn.Linear(in_features=wav2vec_dim, out_features=1024), + nn.LeakyReLU(), + nn.Linear(in_features=1024, out_features=embedding_dim), + ) + + self.save_hyperparameters("embedding_dim") + + @property + def dimension(self) -> int: + """Dimension of output""" + return self.specifications[Subtasks.index("diarization")].num_powerset_classes + + @lru_cache + def num_frames(self, num_samples: int) -> int: + """Compute number of output frames + + Parameters + ---------- + num_samples : int + Number of input samples. + + Returns + ------- + num_frames : int + Number of output frames. + """ + + num_frames = num_samples + for conv_layer in self.wav2vec.feature_extractor.conv_layers: + num_frames = conv1d_num_frames( + num_frames, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + padding=conv_layer.conv.padding[0], + dilation=conv_layer.conv.dilation[0], + ) + + return num_frames + + def receptive_field_size(self, num_frames: int = 1) -> int: + """Compute size of receptive field + + Parameters + ---------- + num_frames : int, optional + Number of frames in the output signal + + Returns + ------- + receptive_field_size : int + Receptive field size. + """ + + receptive_field_size = num_frames + for conv_layer in reversed(self.wav2vec.feature_extractor.conv_layers): + receptive_field_size = conv1d_receptive_field_size( + num_frames=receptive_field_size, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + dilation=conv_layer.conv.dilation[0], + ) + return receptive_field_size + + def receptive_field_center(self, frame: int = 0) -> int: + """Compute center of receptive field + + Parameters + ---------- + frame : int, optional + Frame index + + Returns + ------- + receptive_field_center : int + Index of receptive field center. + """ + receptive_field_center = frame + for conv_layer in reversed(self.wav2vec.feature_extractor.conv_layers): + receptive_field_center = conv1d_receptive_field_center( + receptive_field_center, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + padding=conv_layer.conv.padding[0], + dilation=conv_layer.conv.dilation[0], + ) + return receptive_field_center def build(self): + """""" + max_num_speaker_per_chunk = len(self.specifications[Subtasks.index("diarization")].classes) + max_num_speaker_per_frame = self.specifications[Subtasks.index("diarization")].powerset_max_classes + self.powerset = Powerset( + max_num_speaker_per_chunk, + max_num_speaker_per_frame + ) + if self.hparams.linear["num_layers"] > 0: in_features = self.hparams.linear["hidden_size"] else: - in_features = self.hparams.lstm["hidden_size"] * ( - 2 if self.hparams.lstm["bidirectional"] else 1 - ) + lstm = self.hparams.lstm + in_features = lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1) - diarization_spec = self.specifications[Subtasks.index("diarization")] - out_features = diarization_spec.num_powerset_classes - self.classifier = nn.Linear(in_features, out_features) - self.powerset = Powerset( - len(diarization_spec.classes), - diarization_spec.powerset_max_classes, - ) + self.classifier = nn.Linear(in_features, self.dimension) - def forward( - self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = None - ) -> torch.Tensor: - """ + def forward(self, waveforms: torch.Tensor) -> torch.Tensor: + """Pass forward Parameters ---------- - waveforms : torch.Tensor - Batch of waveforms with shape (batch, channel, sample) - weights : torch.Tensor, optional - Batch of weights wiht shape (batch, frame) + waveforms : (batch, channel, sample) + + Returns + ------- + diarization, embeddings : (batch, frames, classes), (batch, num_speaker, embed_dim) """ - common_outputs = self.sincnet(waveforms) - # (batch, features, frames) - # common part to diarization and embedding: - tdnn_idx = 0 - while tdnn_idx <= self.separation_idx: - common_outputs = self.tdnn_blocks[tdnn_idx](common_outputs) - tdnn_idx = tdnn_idx + 1 - # diarization part: - if self.hparams.lstm["monolithic"]: - diarization_outputs, _ = self.lstm( - rearrange(common_outputs, "batch feature frame -> batch frame feature") - ) + + num_layers = ( + None if self.hparams.wav2vec_layer < 0 else self.hparams.wav2vec_layer + ) + + if self.hparams.freeze_wav2vec: + with torch.no_grad(): + outputs, _ = self.wav2vec.extract_features( + waveforms.squeeze(1), num_layers=num_layers + ) else: - diarization_outputs = rearrange( - common_outputs, "batch feature frame -> batch frame feature" + outputs, _ = self.wav2vec.extract_features( + waveforms.squeeze(1), num_layers=num_layers + ) + + if num_layers is None: + dia_outputs = torch.stack(outputs, dim=-1) @ F.softmax( + self.dia_wav2vec_weights, dim=0 + ) + emb_outputs = torch.stack(outputs, dim=-1) @ F.softmax( + self.emb_wav2vec_weights, dim=0 ) + else: + dia_outputs = emb_outputs = outputs[-1] + + if self.hparams.lstm["monolithic"]: + dia_outputs, _ = self.lstm(dia_outputs) + else: for i, lstm in enumerate(self.lstm): - diarization_outputs, _ = lstm(diarization_outputs) + dia_outputs, _ = lstm(dia_outputs) if i + 1 < self.hparams.lstm["num_layers"]: - diarization_outputs = self.linear() + dia_outputs = self.dropout(dia_outputs) if self.hparams.linear["num_layers"] > 0: for linear in self.linear: - diarization_outputs = F.leaky_relu(linear(diarization_outputs)) - diarization_outputs = self.classifier(diarization_outputs) - diarization_outputs = F.log_softmax(diarization_outputs, dim=-1) - weights = self.powerset(diarization_outputs).transpose(1, 2) + dia_outputs = F.leaky_relu(linear(dia_outputs)) + dia_outputs = self.classifier(dia_outputs) + dia_outputs = F.log_softmax(dia_outputs, dim=-1) - # embedding part: - embedding_outputs = common_outputs - for tdnn_block in self.tdnn_blocks[tdnn_idx:]: - embedding_outputs = tdnn_block(embedding_outputs) - embedding_outputs = self.stats_pool(embedding_outputs, weights=weights) - embedding_outputs = self.embedding(embedding_outputs) + weights = self.powerset.to_multilabel(dia_outputs, soft=True) + weights = rearrange(weights, "b f s -> b s f") + emb_outputs = rearrange(emb_outputs, "b f w -> b w f") + emb_outputs = self.pooling(emb_outputs, weights) + emb_outputs = self.embeddings(emb_outputs) - return (diarization_outputs, embedding_outputs) + return (dia_outputs, emb_outputs) -class SpeakerEndToEndDiarizationV2(Model): - """This version uses a LSTM encoder in the embedding branch instead StatsPool block""" +class WavLMEnd2EndDiarizationv3(Model): + """With modified weights + + Parameters + ---------- + sample_rate : int, optional + Audio sample rate. Defaults to 16kHz (16000). + num_channels : int, optional + Number of channels. Defaults to mono (1). + wav2vec: dict or str, optional + Defaults to "WAVLM_BASE". + wav2vec_layer: int, optional + Index of layer to use as input to the LSTM. + Defaults (-1) to use average of all layers (with learnable weights). + freeze_wav2vec: bool, optional + Whether to freeze wa2vec. Default to true + emb_dim: int, optional + Dimension of the speaker embedding in output + """ + + WAV2VEC_DEFAULTS = "WAVLM_BASE" - SINCNET_DEFAULTS = {"stride": 10} LSTM_DEFAULTS = { "hidden_size": 128, - "num_layers": 2, + "num_layers": 4, "bidirectional": True, "monolithic": True, "dropout": 0.0, - "batch_first": True, } - LINEAR_DEFAULTS = {"hidden_size": 128, "num_layers": 2} + + LINEAR_DEFAULT = {"hidden_size": 128, "num_layers": 2} def __init__( self, - sincnet: dict = None, - lstm: dict = None, - linear: dict = None, sample_rate: int = 16000, num_channels: int = 1, - num_features: int = 60, - embedding_dim: int = 512, - separation_idx: int = 2, + wav2vec: Union[dict, str] = None, + wav2vec_layer: int = -1, + freeze_wav2vec: bool = True, + lstm: Optional[dict] = None, + linear: Optional[dict] = None, + embedding_dim: Optional[int] = 192, task: Optional[Task] = None, ): super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) - if num_features != 60: - warn( - "For now, the model only support a number of features of 60. Set it to 60" + if isinstance(wav2vec, str): + # `wav2vec` is one of the supported pipelines from torchaudio (e.g. "WAVLM_BASE") + if hasattr(torchaudio.pipelines, wav2vec): + bundle = getattr(torchaudio.pipelines, wav2vec) + if sample_rate != bundle.sample_rate: + raise ValueError( + f"Expected {bundle.sample_rate}Hz, found {sample_rate}Hz." + ) + wav2vec_dim = bundle._params["encoder_embed_dim"] + wav2vec_num_layers = bundle._params["encoder_num_layers"] + self.wav2vec = bundle.get_model() + + # `wav2vec` is a path to a self-supervised representation checkpoint + else: + _checkpoint = torch.load(wav2vec) + wav2vec = _checkpoint.pop("config") + self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec) + state_dict = _checkpoint.pop("state_dict") + self.wav2vec.load_state_dict(state_dict) + wav2vec_dim = wav2vec["encoder_embed_dim"] + wav2vec_num_layers = wav2vec["encoder_num_layers"] + + # `wav2vec` is a config dictionary understood by `wav2vec2_model` + # this branch is typically used by Model.from_pretrained(...) + elif isinstance(wav2vec, dict): + self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec) + wav2vec_dim = wav2vec["encoder_embed_dim"] + wav2vec_num_layers = wav2vec["encoder_num_layers"] + + if wav2vec_layer < 0: + # weighting parameters for the diarization branch + self.dia_wav2vec_weights = nn.Parameter( + data=torch.ones(wav2vec_num_layers), requires_grad=True ) - num_features = 60 - self.num_features = num_features - self.separation_idx = separation_idx - self.save_hyperparameters("num_features", "embedding_dim", "separation_idx") - - # sincnet module - sincnet = merge_dict(self.SINCNET_DEFAULTS, sincnet) - sincnet["sample_rate"] = sample_rate - self.sincnet = SincNet(**sincnet) - self.save_hyperparameters("sincnet") - - # tdnn modules - self.tdnn_blocks = nn.ModuleList() - in_channel = num_features - out_channels = [512, 512, 512, 512, 1500] - kernel_sizes = [5, 3, 3, 1, 1] - dilations = [1, 2, 3, 1, 1] - self.last_tdnn_out_channels = out_channels[-1] - - for out_channel, kernel_size, dilation in zip( - out_channels, kernel_sizes, dilations - ): - self.tdnn_blocks.extend( - [ - nn.Sequential( - nn.Conv1d( - in_channels=in_channel, - out_channels=out_channel, - kernel_size=kernel_size, - dilation=dilation, - padding="same", - ), - nn.LeakyReLU(), - nn.BatchNorm1d(out_channel), - ), - ] + # weighting parameters for the embedding branch + self.emb_wav2vec_weights = nn.Parameter( + data=torch.ones(wav2vec_num_layers), requires_grad=True ) - in_channel = out_channel + self.save_hyperparameters("wav2vec", "wav2vec_layer", "freeze_wav2vec") - # lstm modules: lstm = merge_dict(self.LSTM_DEFAULTS, lstm) - self.save_hyperparameters("lstm") + lstm["batch_first"] = True + linear = merge_dict(self.LINEAR_DEFAULT, linear) + self.save_hyperparameters("lstm", "linear") monolithic = lstm["monolithic"] if monolithic: multi_layer_lstm = dict(lstm) del multi_layer_lstm["monolithic"] - self.lstm = nn.LSTM(out_channels[separation_idx], **multi_layer_lstm) + self.lstm = nn.LSTM(wav2vec_dim, **multi_layer_lstm) else: num_layers = lstm["num_layers"] if num_layers > 1: self.dropout = nn.Dropout(p=lstm["dropout"]) one_layer_lstm = dict(lstm) - del one_layer_lstm["monolithic"] one_layer_lstm["num_layers"] = 1 one_layer_lstm["dropout"] = 0.0 + del one_layer_lstm["monolithic"] self.lstm = nn.ModuleList( [ nn.LSTM( - out_channels[separation_idx] - if i == 0 - else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1), - **one_layer_lstm + ( + wav2vec_dim + if i == 0 + else lstm["hidden_size"] + * (2 if lstm["bidirectional"] else 1) + ), + **one_layer_lstm, ) for i in range(num_layers) ] ) - # linear module for the diarization part: - linear = merge_dict(self.LINEAR_DEFAULTS, linear) - self.save_hyperparameters("linear") if linear["num_layers"] < 1: return - - lstm_out_features: int = self.hparams.lstm["hidden_size"] * ( + lstm_out_features = self.hparams.lstm["hidden_size"] * ( 2 if self.hparams.lstm["bidirectional"] else 1 ) self.linear = nn.ModuleList( @@ -473,91 +792,183 @@ def __init__( ] ) - # linear module for the embedding part: - self.embedding = nn.Linear(2 * self.last_tdnn_out_channels, embedding_dim) + self.pooling = StatsPool(computde_std=False) + self.embeddings = nn.Sequential( + nn.Linear(in_features=wav2vec_dim, out_features=1024), + nn.LeakyReLU(), + nn.Linear(in_features=1024, out_features=embedding_dim), + ) - def build(self): - if self.hparams.linear["num_layers"] > 0: - in_features = self.hparams.linear["hidden_size"] - else: - in_features = self.hparams.lstm["hidden_size"] * ( - 2 if self.hparams.lstm["bidirectional"] else 1 + self.save_hyperparameters("embedding_dim") + + @property + def dimension(self) -> int: + """Dimension of output""" + return self.specifications[Subtasks.index("diarization")].num_powerset_classes + + @lru_cache + def num_frames(self, num_samples: int) -> int: + """Compute number of output frames + + Parameters + ---------- + num_samples : int + Number of input samples. + + Returns + ------- + num_frames : int + Number of output frames. + """ + + num_frames = num_samples + for conv_layer in self.wav2vec.feature_extractor.conv_layers: + num_frames = conv1d_num_frames( + num_frames, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + padding=conv_layer.conv.padding[0], + dilation=conv_layer.conv.dilation[0], ) - diarization_spec = self.specifications[Subtasks.index("diarization")] - out_features = diarization_spec.num_powerset_classes - self.classifier = nn.Linear(in_features, out_features) + return num_frames + + def receptive_field_size(self, num_frames: int = 1) -> int: + """Compute size of receptive field + + Parameters + ---------- + num_frames : int, optional + Number of frames in the output signal + + Returns + ------- + receptive_field_size : int + Receptive field size. + """ + + receptive_field_size = num_frames + for conv_layer in reversed(self.wav2vec.feature_extractor.conv_layers): + receptive_field_size = conv1d_receptive_field_size( + num_frames=receptive_field_size, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + dilation=conv_layer.conv.dilation[0], + ) + return receptive_field_size + + def receptive_field_center(self, frame: int = 0) -> int: + """Compute center of receptive field + + Parameters + ---------- + frame : int, optional + Frame index + + Returns + ------- + receptive_field_center : int + Index of receptive field center. + """ + receptive_field_center = frame + for conv_layer in reversed(self.wav2vec.feature_extractor.conv_layers): + receptive_field_center = conv1d_receptive_field_center( + receptive_field_center, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + padding=conv_layer.conv.padding[0], + dilation=conv_layer.conv.dilation[0], + ) + return receptive_field_center + def build(self): + """""" + max_num_speaker_per_chunk = len(self.specifications[Subtasks.index("diarization")].classes) + max_num_speaker_per_frame = self.specifications[Subtasks.index("diarization")].powerset_max_classes self.powerset = Powerset( - len(diarization_spec.classes), - diarization_spec.powerset_max_classes, + max_num_speaker_per_chunk, + max_num_speaker_per_frame ) - lstm_out_features: int = self.hparams.lstm["hidden_size"] * ( - 2 if self.hparams.lstm["bidirectional"] else 1 - ) - self.encoder = nn.LSTM( - # number of channel in the outputs of the last TDNN layer + lstm_out_features - input_size=self.last_tdnn_out_channels + lstm_out_features, - hidden_size=len(diarization_spec.classes) * self.last_tdnn_out_channels, - batch_first=True, - bidirectional=True, - ) + if self.hparams.linear["num_layers"] > 0: + in_features = self.hparams.linear["hidden_size"] + else: + lstm = self.hparams.lstm + in_features = lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1) - def forward( - self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = None - ) -> torch.Tensor: - """ + self.classifier = nn.Linear(in_features, self.dimension) + + def forward(self, waveforms: torch.Tensor) -> torch.Tensor: + """Pass forward Parameters ---------- - waveforms : torch.Tensor - Batch of waveforms with shape (batch, channel, sample) - weights : torch.Tensor, optional - Batch of weights wiht shape (batch, frame) + waveforms : (batch, channel, sample) + + Returns + ------- + diarization, embeddings : (batch, frames, classes), (batch, num_speaker, embed_dim) """ - common_outputs = self.sincnet(waveforms) - # (batch, features, frames) - # common part to diarization and embedding: - tdnn_idx = 0 - while tdnn_idx <= self.separation_idx: - common_outputs = self.tdnn_blocks[tdnn_idx](common_outputs) - tdnn_idx = tdnn_idx + 1 - # diarization part: - dia_outputs = common_outputs - if self.hparams.lstm["monolithic"]: - dia_outputs, _ = self.lstm( - rearrange(dia_outputs, "batch feature frame -> batch frame feature") - ) + num_layers = ( + None if self.hparams.wav2vec_layer < 0 else self.hparams.wav2vec_layer + ) + + if self.hparams.freeze_wav2vec: + with torch.no_grad(): + outputs, _ = self.wav2vec.extract_features( + waveforms.squeeze(1), num_layers=num_layers + ) else: - dia_outputs = rearrange( - common_outputs, "batch feature frame -> batch frame feature" + outputs, _ = self.wav2vec.extract_features( + waveforms.squeeze(1), num_layers=num_layers + ) + + if num_layers is None: + dia_outputs = torch.stack(outputs, dim=-1) @ F.softmax( + self.dia_wav2vec_weights, dim=0 + ) + emb_outputs = torch.stack(outputs, dim=-1) @ F.softmax( + self.emb_wav2vec_weights, dim=0 ) + else: + dia_outputs = emb_outputs = outputs[-1] + + if self.hparams.lstm["monolithic"]: + dia_outputs, _ = self.lstm(dia_outputs) + else: for i, lstm in enumerate(self.lstm): dia_outputs, _ = lstm(dia_outputs) if i + 1 < self.hparams.lstm["num_layers"]: - dia_outputs = self.linear() - lstm_outputs = dia_outputs + dia_outputs = self.dropout(dia_outputs) + if self.hparams.linear["num_layers"] > 0: for linear in self.linear: dia_outputs = F.leaky_relu(linear(dia_outputs)) dia_outputs = self.classifier(dia_outputs) dia_outputs = F.log_softmax(dia_outputs, dim=-1) - # embedding part: - emb_outputs = common_outputs - for tdnn_block in self.tdnn_blocks[tdnn_idx:]: - emb_outputs = tdnn_block(emb_outputs) - - emb_outputs = rearrange(emb_outputs, "b c f -> b f c") - # Concatenation of last tdnn layer outputs with the last diarization lstm outputs: - emb_outputs = torch.cat((emb_outputs, lstm_outputs), dim=2) - _, emb_outputs = self.encoder(emb_outputs) - emb_outputs = rearrange(emb_outputs[0], "l b h -> b (l h)") - emb_outputs = torch.reshape( - emb_outputs, (emb_outputs.shape[0], self.powerset.num_classes, -1) - ) - emb_outputs = self.embedding(emb_outputs) + # hard-segmentation in multilabel space + multilabel_segmentations: torch.Tensor = self.powerset.to_multilabel(dia_outputs) + # (batch_size, num_frames, max_speakers_per_chunk), {0, 1} + + weights = ( + ( + F.one_hot( + torch.argmax(dia_outputs, dim=2), + num_classes=self.powerset.num_powerset_classes, + )[:, :, 1 : 1 + self.powerset.num_classes] + + 1e-2 + ) + * multilabel_segmentations + ).transpose(2, 1) + # (batch_size, max_speakers_per_chunk, num_frames) + # 0.000 if speaker is inactive + # 0.001 if speaker is active but not alone + # 1.001 if speaker is active and alone + + emb_outputs = rearrange(emb_outputs, "b f w -> b w f") + emb_outputs = self.pooling(emb_outputs, weights) + emb_outputs = self.embeddings(emb_outputs) return (dia_outputs, emb_outputs) From 8608a1c169d7e839d619a90acdb2065e4e3b5533 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Mon, 8 Jul 2024 21:47:30 +0200 Subject: [PATCH 77/83] wip: add pipeline working with joint model --- .../audio/pipelines/speaker_diarization.py | 377 +++++++++++++++++- 1 file changed, 375 insertions(+), 2 deletions(-) diff --git a/pyannote/audio/pipelines/speaker_diarization.py b/pyannote/audio/pipelines/speaker_diarization.py index edfa5966c..a46425b64 100644 --- a/pyannote/audio/pipelines/speaker_diarization.py +++ b/pyannote/audio/pipelines/speaker_diarization.py @@ -27,7 +27,7 @@ import math import textwrap import warnings -from typing import Callable, Optional, Text, Union +from typing import Callable, Optional, Text, Tuple, Union import numpy as np import torch @@ -35,10 +35,12 @@ from pyannote.core import Annotation, SlidingWindowFeature from pyannote.metrics.diarization import GreedyDiarizationErrorRate from pyannote.pipeline.parameter import ParamDict, Uniform +from sklearn.cluster import AgglomerativeClustering from pyannote.audio import Audio, Inference, Model, Pipeline from pyannote.audio.core.io import AudioFile -from pyannote.audio.pipelines.clustering import Clustering +from pyannote.audio.core.task import Problem, Resolution +from pyannote.audio.pipelines.clustering import AgglomerativeClustering, Clustering from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding from pyannote.audio.pipelines.utils import ( PipelineModel, @@ -647,3 +649,374 @@ def apply( def get_metric(self) -> GreedyDiarizationErrorRate: return GreedyDiarizationErrorRate(**self.der_variant) + + +class SpeakerDiarizationV2(SpeakerDiarizationMixin, Pipeline): + """Speaker diarization pipeline with joint segmentation + embedding model + + Parameters + ---------- + model : Model, str, or dict, optional + Pretrained (segmentation + embedding) model. + See pyannote.audio.pipelines.utils.get_model for supported format. + step: float, optional + The model is applied on a window sliding over the whole audio file. + `step` controls the step of this window, provided as a ratio of its + duration. Defaults to 0.1 (i.e. 90% overlap between two consecuive windows). + clustering : str, optional + Clustering algorithm. See pyannote.audio.pipelines.clustering.Clustering + for available options. Defaults to "AgglomerativeClustering". + batch_size : int, optional + Batch size used for inference. Defaults to 1. + use_auth_token : str, optional + When loading private huggingface.co models, set `use_auth_token` + to True or to a string containing your hugginface.co authentication + token that can be obtained by running `huggingface-cli login` + + Usage + ----- + # perform (unconstrained) diarization + >>> diarization = pipeline("/path/to/audio.wav") + + # perform diarization, targetting exactly 4 speakers + >>> diarization = pipeline("/path/to/audio.wav", num_speakers=4) + + # perform diarization, with at least 2 speakers and at most 10 speakers + >>> diarization = pipeline("/path/to/audio.wav", min_speakers=2, max_speakers=10) + + # perform diarization and get one representative embedding per speaker + >>> diarization, embeddings = pipeline("/path/to/audio.wav", return_embeddings=True) + >>> for s, speaker in enumerate(diarization.labels()): + ... # embeddings[s] is the embedding of speaker `speaker` + + """ + + def __init__( + self, + model: PipelineModel = None, + step: float = 0.1, + clustering: str = "AgglomerativeClustering", + batch_size: int = 1, + use_auth_token: Union[Text, None] = None, + ): + super().__init__() + + self.model = model + model: Model = get_model(model, use_auth_token=use_auth_token) + + assert len(model.specifications) == 2 + segmentation_specifications, embedding_specifications = model.specifications + # TODO: check that specs are correct + assert segmentation_specifications.problem == Problem.MONO_LABEL_CLASSIFICATION + assert segmentation_specifications.resolution == Resolution.FRAME + assert embedding_specifications.problem == Problem.REPRESENTATION + assert embedding_specifications.resolution == Resolution.CHUNK + + self.step = step + self.klustering = clustering + + duration: float = segmentation_specifications.duration + self._inference = Inference( + model, + duration=duration, + step=self.step * duration, + skip_aggregation=True, + skip_conversion=False, # <-- output multilabel segmentation + batch_size=batch_size, + ) + + self.clustering = AgglomerativeClustering(metric="cosine") + + @property + def batch_size(self) -> int: + return self._inference.batch_size + + @batch_size.setter + def batch_size(self, batch_size: int): + self._inference.batch_size = batch_size + + def default_parameters(self): + raise NotImplementedError() + + def classes(self): + speaker = 0 + while True: + yield f"SPEAKER_{speaker:02d}" + speaker += 1 + + @property + def CACHED_INFERENCE(self): + return "training_cache/inference" + + def get_inference(self, file, hook=None) -> Tuple[SlidingWindowFeature]: + """Apply joint model + + Parameter + --------- + file : AudioFile + hook : Optional[Callable] + + Returns + ------- + segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature + embeddings : (num_chunks, num_speakers, dimension) SlidingWindowFeature + """ + + if hook is not None: + hook = functools.partial(hook, "inference", None) + + if self.training: + if self.CACHED_INFERENCE in file: + inference = file[self.CACHED_INFERENCE] + else: + inference = self._inference(file, hook=hook) + file[self.CACHED_INFERENCE] = inference + else: + inference = self._inference(file, hook=hook) + + return inference + + def reconstruct( + self, + segmentations: SlidingWindowFeature, + hard_clusters: np.ndarray, + count: SlidingWindowFeature, + ) -> SlidingWindowFeature: + """Build final discrete diarization out of clustered segmentation + + Parameters + ---------- + segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature + Raw speaker segmentation. + hard_clusters : (num_chunks, num_speakers) array + Output of clustering step. + count : (total_num_frames, 1) SlidingWindowFeature + Instantaneous number of active speakers. + + Returns + ------- + discrete_diarization : SlidingWindowFeature + Discrete (0s and 1s) diarization. + """ + + num_chunks, num_frames, local_num_speakers = segmentations.data.shape + + num_clusters = np.max(hard_clusters) + 1 + clustered_segmentations = np.nan * np.zeros( + (num_chunks, num_frames, num_clusters) + ) + + for c, (cluster, (chunk, segmentation)) in enumerate( + zip(hard_clusters, segmentations) + ): + # cluster is (local_num_speakers, )-shaped + # segmentation is (num_frames, local_num_speakers)-shaped + for k in np.unique(cluster): + if k == -2: + continue + + # TODO: can we do better than this max here? + clustered_segmentations[c, :, k] = np.max( + segmentation[:, cluster == k], axis=1 + ) + + clustered_segmentations = SlidingWindowFeature( + clustered_segmentations, segmentations.sliding_window + ) + + return self.to_diarization(clustered_segmentations, count) + + def apply( + self, + file: AudioFile, + num_speakers: Optional[int] = None, + min_speakers: Optional[int] = None, + max_speakers: Optional[int] = None, + return_embeddings: bool = False, + hook: Optional[Callable] = None, + ) -> Union[Annotation, Tuple[Annotation, np.ndarray]]: + """Apply speaker diarization + + Parameters + ---------- + file : AudioFile + Processed file. + num_speakers : int, optional + Number of speakers, when known. + min_speakers : int, optional + Minimum number of speakers. Has no effect when `num_speakers` is provided. + max_speakers : int, optional + Maximum number of speakers. Has no effect when `num_speakers` is provided. + return_embeddings : bool, optional + Return representative speaker embeddings. + hook : callable, optional + Callback called after each major steps of the pipeline as follows: + hook(step_name, # human-readable name of current step + step_artefact, # artifact generated by current step + file=file) # file being processed + Time-consuming steps call `hook` multiple times with the same `step_name` + and additional `completed` and `total` keyword arguments usable to track + progress of current step. + + Returns + ------- + diarization : Annotation + Speaker diarization + embeddings : np.array, optional + Representative speaker embeddings such that `embeddings[i]` is the + speaker embedding for i-th speaker in diarization.labels(). + Only returned when `return_embeddings` is True. + """ + + # setup hook (e.g. for debugging purposes) + hook = self.setup_hook(file, hook=hook) + + num_speakers, min_speakers, max_speakers = self.set_num_speakers( + num_speakers=num_speakers, + min_speakers=min_speakers, + max_speakers=max_speakers, + ) + + inference = self.get_inference(file, hook=hook) + hook("inference", inference) + binarized_segmentations, embeddings = inference + # shape: (num_chunks, num_frames, local_num_speakers) + num_chunks, num_frames, local_num_speakers = binarized_segmentations.data.shape + _, _, dimension = embeddings.data.shape + + # estimate frame-level number of instantaneous speakers + count = self.speaker_count( + binarized_segmentations, + self._inference.model.receptive_field, + warm_up=(0.0, 0.0), + ) + hook("speaker_counting", count) + # shape: (num_frames, 1) + # dtype: int + + # exit early when no speaker is ever active + if np.nanmax(count.data) == 0.0: + diarization = Annotation(uri=file["uri"]) + if return_embeddings: + return diarization, np.zeros((0, dimension)) + + return diarization + + hard_clusters, _, centroids = self.clustering( + embeddings=embeddings.data, + segmentations=binarized_segmentations, + num_clusters=num_speakers, + min_clusters=min_speakers, + max_clusters=max_speakers, + file=file, # <== for oracle clustering + frames=self._inference.model.receptive_field, # <== for oracle clustering + ) + # hard_clusters: (num_chunks, num_speakers) + # centroids: (num_speakers, dimension) + + # number of detected clusters is the number of different speakers + num_different_speakers = np.max(hard_clusters) + 1 + + # detected number of speakers can still be out of bounds + # (specifically, lower than `min_speakers`), since there could be too few embeddings + # to make enough clusters with a given minimum cluster size. + if ( + num_different_speakers < min_speakers + or num_different_speakers > max_speakers + ): + warnings.warn( + textwrap.dedent( + f""" + The detected number of speakers ({num_different_speakers}) is outside + the given bounds [{min_speakers}, {max_speakers}]. This can happen if the + given audio file is too short to contain {min_speakers} or more speakers. + Try to lower the desired minimal number of speakers. + """ + ) + ) + + # during counting, we could possibly overcount the number of instantaneous + # speakers due to segmentation errors, so we cap the maximum instantaneous number + # of speakers by the `max_speakers` value + count.data = np.minimum(count.data, max_speakers).astype(np.int8) + + # reconstruct discrete diarization from raw hard clusters + + # keep track of inactive speakers + inactive_speakers = np.sum(binarized_segmentations.data, axis=1) == 0 + # shape: (num_chunks, num_speakers) + + hard_clusters[inactive_speakers] = -2 + discrete_diarization = self.reconstruct( + binarized_segmentations, + hard_clusters, + count, + ) + hook("discrete_diarization", discrete_diarization) + + # convert to continuous diarization + diarization = self.to_annotation( + discrete_diarization, + min_duration_on=0.0, + min_duration_off=0.0, + ) + diarization.uri = file["uri"] + + # at this point, `diarization` speaker labels are integers + # from 0 to `num_speakers - 1`, aligned with `centroids` rows. + + if "annotation" in file and file["annotation"]: + # when reference is available, use it to map hypothesized speakers + # to reference speakers (this makes later error analysis easier + # but does not modify the actual output of the diarization pipeline) + _, mapping = self.optimal_mapping( + file["annotation"], diarization, return_mapping=True + ) + + # in case there are more speakers in the hypothesis than in + # the reference, those extra speakers are missing from `mapping`. + # we add them back here + mapping = {key: mapping.get(key, key) for key in diarization.labels()} + + else: + # when reference is not available, rename hypothesized speakers + # to human-readable SPEAKER_00, SPEAKER_01, ... + mapping = { + label: expected_label + for label, expected_label in zip(diarization.labels(), self.classes()) + } + + diarization = diarization.rename_labels(mapping=mapping) + + # at this point, `diarization` speaker labels are strings (or mix of + # strings and integers when reference is available and some hypothesis + # speakers are not present in the reference) + + if not return_embeddings: + return diarization + + # this can happen when we use OracleClustering + if centroids is None: + return diarization, None + + # The number of centroids may be smaller than the number of speakers + # in the annotation. This can happen if the number of active speakers + # obtained from `speaker_count` for some frames is larger than the number + # of clusters obtained from `clustering`. In this case, we append zero embeddings + # for extra speakers + if len(diarization.labels()) > centroids.shape[0]: + centroids = np.pad( + centroids, ((0, len(diarization.labels()) - centroids.shape[0]), (0, 0)) + ) + + # re-order centroids so that they match + # the order given by diarization.labels() + inverse_mapping = {label: index for index, label in mapping.items()} + centroids = centroids[ + [inverse_mapping[label] for label in diarization.labels()] + ] + + return diarization, centroids + + def get_metric(self) -> GreedyDiarizationErrorRate: + return GreedyDiarizationErrorRate(**self.der_variant) From b91df8cc6406a64876b6bd7b7d30e2280785ef49 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Fri, 25 Oct 2024 13:46:25 +0200 Subject: [PATCH 78/83] wip: add validation pipeline --- .../speaker_diarization_and_embedding.py | 411 +++++++++++++----- 1 file changed, 299 insertions(+), 112 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 76833c997..866df830e 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -22,27 +22,31 @@ from collections import defaultdict import itertools -import math from pathlib import Path import random import warnings from tempfile import mkstemp -from typing import Dict, Literal, Optional, Sequence, Union +from typing import Dict, Literal, Optional, Sequence, Tuple, Union import numpy as np import torch from einops import rearrange from matplotlib import pyplot as plt -from pyannote.core import Segment, SlidingWindowFeature +from pyannote.core import ( + Annotation, + Segment, + SlidingWindow, + SlidingWindowFeature, + Timeline, +) from pyannote.database.protocol.protocol import Scope, Subset -from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger from pytorch_metric_learning.losses import ArcFaceLoss from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric -from pyannote.audio.core.task import ( - Problem, Resolution, Specifications, get_dtype -) +from scipy.spatial.distance import cdist + +from pyannote.audio.core.task import Problem, Resolution, Specifications, get_dtype from pyannote.audio.tasks import SpeakerDiarization from pyannote.audio.torchmetrics import ( DiarizationErrorRate, @@ -53,6 +57,13 @@ from pyannote.audio.utils.loss import nll_loss from pyannote.audio.utils.permutation import permutate from pyannote.audio.utils.random import create_rng_for_worker +from pyannote.audio.pipelines.clustering import KMeansClustering, OracleClustering +from pyannote.audio.pipelines.utils import SpeakerDiarizationMixin +from pyannote.audio.core.io import Audio + +from pyannote.metrics.diarization import ( + DiarizationErrorRate as GlobalDiarizationErrorRate, +) Subtask = Literal["diarization", "embedding"] @@ -260,7 +271,7 @@ def prepare_data(self): for segment in file["annotated"]: # skip annotated regions that are shorter than training chunk duration # if segment.duration < self.duration: - # continue + # continue # append annotated region annotated_region = ( @@ -421,6 +432,13 @@ def prepare_data(self): with open(self.cache, "wb") as cache_file: np.savez_compressed(cache_file, **prepared_data) + def prepare_validation(self, prepared_data: Dict) -> None: + """Each validation batch correspond to a part of a validation file""" + validation_mask = prepared_data["audio-metadata"]["subset"] == Subsets.index( + "development" + ) + prepared_data["validation-files"] = np.argwhere(validation_mask).reshape((-1,)) + def setup(self, stage="fit"): """Setup method @@ -555,7 +573,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): start_idx, end_idx, chunk_annotations[label_scope_key] ): mapped_label = mapping[label] - y[start:end + 1, mapped_label] = 1 + y[start : end + 1, mapped_label] = 1 sample["y"] = SlidingWindowFeature(y, self.model.receptive_field, labels=labels) @@ -762,6 +780,39 @@ def train__iter__(self): # generate random chunk yield next(chunks) + def val__getitem__(self, idx) -> Dict: + """Validation items are generated so that all samples in a batch come from the same + validation file. These samples are created by sliding a window over the first seconds of + the validation file, with a step (for now arbitrally) set to 0.2 (20% of the task duration, + e.g. 1 second for a duration of 5 seconds)""" + + file_idx = idx // self.batch_size + chunk_idx = idx % self.batch_size + + file_id = self.prepared_data["validation-files"][file_idx] + file = next( + itertools.islice(self.protocol.development(), file_idx, file_idx + 1) + ) + + file_duration = file.get( + "duration", Audio("downmix").get_duration(file["audio"]) + ) + start_time = chunk_idx * ( + (file_duration - self.duration) / (self.batch_size - 1) + ) + + chunk = self.prepare_chunk(file_id, start_time, self.duration) + + if chunk_idx == 0: + chunk["annotation"] = file["annotation"] + + chunk["start_time"] = start_time + + return chunk + + def val__len__(self): + return len(self.prepared_data["validation-files"]) * self.batch_size + def collate_y(self, batch) -> torch.Tensor: """ Parameters @@ -859,14 +910,19 @@ def collate_fn(self, batch, stage="train"): sample_rate=self.model.hparams.sample_rate, targets=collated_y_dia.unsqueeze(1), ) - - return { + collated_batch = { "X": augmented.samples, "y_dia": augmented.targets.squeeze(1), "y_emb": collate_y_emb, "meta": collated_meta, } + if stage == "val": + collated_batch["annotation"] = batch[0]["annotation"] + collated_batch["start_times"] = [b["start_time"] for b in batch] + + return collated_batch + def setup_loss_func(self): self.model.arc_face_loss = ArcFaceLoss( len(self.specifications[Subtasks.index("embedding")].classes), @@ -971,7 +1027,7 @@ def compute_embedding_loss(self, emb_prediction, target_emb, valid_embs): ) if torch.any(valid_embs): - emb_loss = (1. / torch.sum(valid_embs)) * emb_loss + emb_loss = (1.0 / torch.sum(valid_embs)) * emb_loss # skip batch if something went wrong for some reason if torch.isnan(emb_loss): @@ -1072,6 +1128,152 @@ def training_step(self, batch, batch_idx: int): return {"loss": loss} + def reconstruct( + self, + segmentations: SlidingWindowFeature, + clusters: np.ndarray, + ) -> SlidingWindowFeature: + """Build final discrete diarization out of clustered segmentation + + Parameters + ---------- + segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature + Raw speaker segmentation. + hard_clusters : (num_chunks, num_speakers) array + Output of clustering step. + count : (total_num_frames, 1) SlidingWindowFeature + Instantaneous number of active speakers. + + Returns + ------- + discrete_diarization : SlidingWindowFeature + Discrete (0s and 1s) diarization. + """ + + num_chunks, num_frames, _ = segmentations.data.shape + num_clusters = np.max(clusters) + 1 + clustered_segmentations = np.nan * np.zeros( + (num_chunks, num_frames, num_clusters) + ) + + for c, (cluster, (chunk, segmentation)) in enumerate( + zip(clusters, segmentations) + ): + # cluster is (local_num_speakers, )-shaped + # segmentation is (num_frames, local_num_speakers)-shaped + for k in np.unique(cluster): + if k == -2: + continue + + # TODO: can we do better than this max here? + clustered_segmentations[c, :, k] = np.max( + segmentation[:, cluster == k], axis=1 + ) + + clustered_segmentations = SlidingWindowFeature( + clustered_segmentations, segmentations.sliding_window + ) + return clustered_segmentations + + def aggregate(self, segmentations: SlidingWindowFeature, pad_duration:float) -> SlidingWindowFeature: + num_chunks, num_frames, num_speakers = segmentations.data.shape + sliding_window = segmentations.sliding_window + frame_duration = sliding_window.duration / num_frames + + if num_chunks <= 1: + return segmentations[0] + + num_padding_frames = np.round( + pad_duration / frame_duration + ).astype(np.uint32) + aggregated_segmentation = segmentations[0] + + for chunk_segmentation in segmentations[1:]: + padding = np.zeros((num_padding_frames, num_speakers)) + aggregated_segmentation = np.concatenate( + (aggregated_segmentation, padding, chunk_segmentation) + ) + return SlidingWindowFeature(aggregated_segmentation.astype(np.int8), SlidingWindow(step=frame_duration, duration=frame_duration)) + + def to_diarization( + self, + segmentations: SlidingWindowFeature, + pad_duration: float = 0., + ) -> SlidingWindowFeature: + """Build diarization out of preprocessed segmentation and precomputed speaker count + + Parameters + ---------- + segmentations : SlidingWindowFeature + (num_chunks, num_frames, num_speakers)-shaped segmentations + count : SlidingWindow_feature + (num_frames, 1)-shaped speaker count + + Returns + ------- + discrete_diarization : SlidingWindowFeature + Discrete (0s and 1s) diarization. + """ + + activations = self.aggregate(segmentations, pad_duration=pad_duration) + # shape: (num_frames, num_speakers) + _, num_speakers = activations.data.shape + + count = np.sum(activations, axis=1, keepdims=True) + # shape: (num_frames, 1) + + max_speakers_per_frame = np.max(count.data) + if num_speakers < max_speakers_per_frame: + activations.data = np.pad( + activations.data, ((0, 0), (0, max_speakers_per_frame - num_speakers)) + ) + + extent = activations.extent & count.extent + activations = activations.crop(extent, return_data=False) + count = count.crop(extent, return_data=False) + + sorted_speakers = np.argsort(-activations, axis=-1) + binary = np.zeros_like(activations.data) + + for t, ((_, c), speakers) in enumerate(zip(count, sorted_speakers)): + for i in range(c.item()): + binary[t, speakers[i]] = 1.0 + + return SlidingWindowFeature(binary, activations.sliding_window) + + def compute_metric( + self, + reference: Annotation, + hypothesis: Tuple[SlidingWindowFeature, np.ndarray], + pad_duration: float, + ): + """Compute diarization annotation from binarized segmentation and cluster (num_chunk, num_speaker)""" + frames = self.model.receptive_field + binarized_segmentations, clusters = hypothesis + + # keep track of inactive speakers + inactive_speakers = np.sum(binarized_segmentations.data, axis=1) == 0 + # shape: (num_chunks, num_speakers) + clusters[inactive_speakers] = -2 + + clustered_segmentations = self.reconstruct( + binarized_segmentations, clusters + ) + + binarized_diarization = self.to_diarization(clustered_segmentations, pad_duration=pad_duration) + diarization = SpeakerDiarizationMixin.to_annotation(binarized_diarization) + + metric = GlobalDiarizationErrorRate() + metric(reference, diarization, detailed=True) + + result = metric[:] + metric_dict = {"der": 0.} + for component in ["false alarm", "missed detection", "confusion"]: + metric_dict[component] = (result[component] / result["total"]) + metric_dict["der"] += metric_dict[component] + + return metric_dict + # TODO: no need to compute gradient in this method def validation_step(self, batch, batch_idx: int): """Compute validation loss and metric @@ -1079,124 +1281,111 @@ def validation_step(self, batch, batch_idx: int): Parameters ---------- batch : dict of torch.Tensor - Current batch. + current batch. All chunks come from the same + file and are in chronological order batch_idx: int Batch index. """ - # target - target_dia = batch["y_dia"] - # (batch_size, num_frames, num_speakers) + # get reference + reference = batch["annotation"] + num_speakers = len(reference.labels()) - waveform = batch["X"] - # (batch_size, num_channels, num_samples) + frames = self.model.receptive_field - # TODO: should we handle validation samples with too many speakers - # waveform = waveform[keep] - # target = target[keep] + start_times = batch["start_times"] - # forward pass - dia_prediction, _ = self.model(waveform) - batch_size, num_frames, _ = dia_prediction.shape + file_id = batch["meta"]["file"][0] + file = self.get_file(file_id) + file["annotation"] = reference + print(reference.uri) + print(file["audio"]) + + assert reference.uri in file["audio"] + + # build support timeline from chunk segments + support = Timeline() + for start_time in start_times: + support.add(Segment(start_time, start_time + self.duration)) + print("support=", support) + + # keep reference only on chunk segments: + reference = reference.crop(support) + # corner case where no reference segments intersects the timeline + if len(reference) == 0: + return None - multilabel = self.model.powerset.to_multilabel(dia_prediction) - permutated_target, _ = permutate(multilabel, target_dia) + waveform = batch["X"] + #shape: (num_chunks, num_channels, local_num_samples) - # FIXME: handle case where target have too many speakers? - # since we don't need - permutated_target_powerset = self.model.powerset.to_powerset( - permutated_target.float() - ) - seg_loss = self.segmentation_loss( - dia_prediction, - permutated_target_powerset, - ) + # segmentation + embeddings extraction step + segmentations, embeddings = self.model(waveform) + # shapes: (num_chunks, num_frames, powerset_classes), (num_chunks, local_num_speakers, embed_dim) - self.model.log( - "loss/val/dia", - seg_loss, - on_step=False, - on_epoch=True, - prog_bar=False, - logger=True, - ) + if self.batch_size > 1: + step = batch["start_times"][1] - batch["start_times"][0] + else: + step = self.duration - self.model.validation_metric( - torch.transpose(multilabel, 1, 2), - torch.transpose(target_dia, 1, 2), + sliding_window = SlidingWindow( + start=batch["start_times"][0], duration=self.duration, step=step ) - self.model.log_dict( - self.model.validation_metric, - on_step=False, - on_epoch=True, - prog_bar=True, - logger=True, - ) + binarized_segmentations = self.model.powerset.to_multilabel(segmentations) - # log first batch visualization every 2^n epochs. - if ( - self.model.current_epoch == 0 - or math.log2(self.model.current_epoch) % 1 > 0 - or batch_idx > 0 - ): - return + binarized_segmentations = binarized_segmentations.cpu().detach().numpy() + binarized_segmentations = SlidingWindowFeature( + binarized_segmentations, sliding_window + ) - # visualize first 9 validation samples of first batch in Tensorboard/MLflow + embeddings = embeddings.cpu().detach().numpy() - y = permutated_target.float().cpu().numpy() - y_pred = multilabel.cpu().numpy() + # clustering step + clustering = KMeansClustering() + hard_clusters, _, _ = clustering( + embeddings=embeddings, + segmentations=binarized_segmentations, + num_clusters=num_speakers, + ) + oracle_clustering = OracleClustering() + oracle_hard_clusters, _, _ = oracle_clustering( + segmentations=binarized_segmentations, + file=file, + frames=self.model.receptive_field.step, + ) - # prepare 3 x 3 grid (or smaller if batch size is smaller) - num_samples = min(self.batch_size, 9) - nrows = math.ceil(math.sqrt(num_samples)) - ncols = math.ceil(num_samples / nrows) - fig, axes = plt.subplots( - nrows=2 * nrows, ncols=ncols, figsize=(8, 5), squeeze=False + pad_duration = step - self.duration + der = self.compute_metric( + reference=reference, + hypothesis=(binarized_segmentations, hard_clusters), + pad_duration=pad_duration, ) + # oder = self.compute_metric( + # reference=reference, + # hypothesis=(binarized_segmentations, oracle_hard_clusters), + # pad_duration=pad_duration, + # ) + + for key in der: + self.model.log( + f"BS={self.batch_size}-Duration={self.duration}s/DER/{key}", + der[key], + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) - # reshape target so that there is one line per class when plotting it - y[y == 0] = np.NaN - if len(y.shape) == 2: - y = y[:, :, np.newaxis] - y *= np.arange(y.shape[2]) - - # plot each sample - for sample_idx in range(num_samples): - # find where in the grid it should be plotted - row_idx = sample_idx // nrows - col_idx = sample_idx % ncols - - # plot target - ax_ref = axes[row_idx * 2 + 0, col_idx] - sample_y = y[sample_idx] - ax_ref.plot(sample_y) - ax_ref.set_xlim(0, len(sample_y)) - ax_ref.set_ylim(-1, sample_y.shape[1]) - ax_ref.get_xaxis().set_visible(False) - ax_ref.get_yaxis().set_visible(False) - - # plot predictions - ax_hyp = axes[row_idx * 2 + 1, col_idx] - sample_y_pred = y_pred[sample_idx] - ax_hyp.plot(sample_y_pred) - ax_hyp.set_ylim(-0.1, 1.1) - ax_hyp.set_xlim(0, len(sample_y)) - ax_hyp.get_xaxis().set_visible(False) - - plt.tight_layout() - - for logger in self.model.loggers: - if isinstance(logger, TensorBoardLogger): - logger.experiment.add_figure("samples", fig, self.model.current_epoch) - elif isinstance(logger, MLFlowLogger): - logger.experiment.log_figure( - run_id=logger.run_id, - figure=fig, - artifact_file=f"samples_epoch{self.model.current_epoch}.png", - ) + # self.model.log( + # f"BS={self.batch_size}-Duration={self.duration}s/ODER/{key}", + # oder[key], + # on_step=False, + # on_epoch=True, + # prog_bar=True, + # logger=True, + # ) - plt.close(fig) + return None def default_metric( self, @@ -1209,6 +1398,4 @@ def default_metric( "DiarizationErrorRate/Confusion": SpeakerConfusionRate(0.5), "DiarizationErrorRate/Miss": MissedDetectionRate(0.5), "DiarizationErrorRate/FalseAlarm": FalseAlarmRate(0.5), - # "EqualErrorRate": EqualErrorRate(compute_on_cpu=True, distances=False), - # "BinaryAUROC": BinaryAUROC(compute_on_cpu=True), } From 5e541088f03f4678e602d0fe811cdb84cc735d6e Mon Sep 17 00:00:00 2001 From: clement-pages Date: Fri, 25 Oct 2024 15:33:13 +0200 Subject: [PATCH 79/83] clean validation pipeline code --- .../speaker_diarization_and_embedding.py | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 866df830e..578da632f 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -1298,8 +1298,6 @@ def validation_step(self, batch, batch_idx: int): file_id = batch["meta"]["file"][0] file = self.get_file(file_id) file["annotation"] = reference - print(reference.uri) - print(file["audio"]) assert reference.uri in file["audio"] @@ -1307,7 +1305,6 @@ def validation_step(self, batch, batch_idx: int): support = Timeline() for start_time in start_times: support.add(Segment(start_time, start_time + self.duration)) - print("support=", support) # keep reference only on chunk segments: reference = reference.crop(support) @@ -1360,11 +1357,12 @@ def validation_step(self, batch, batch_idx: int): hypothesis=(binarized_segmentations, hard_clusters), pad_duration=pad_duration, ) - # oder = self.compute_metric( - # reference=reference, - # hypothesis=(binarized_segmentations, oracle_hard_clusters), - # pad_duration=pad_duration, - # ) + + oder = self.compute_metric( + reference=reference, + hypothesis=(binarized_segmentations, oracle_hard_clusters), + pad_duration=pad_duration, + ) for key in der: self.model.log( @@ -1376,14 +1374,14 @@ def validation_step(self, batch, batch_idx: int): logger=True, ) - # self.model.log( - # f"BS={self.batch_size}-Duration={self.duration}s/ODER/{key}", - # oder[key], - # on_step=False, - # on_epoch=True, - # prog_bar=True, - # logger=True, - # ) + self.model.log( + f"BS={self.batch_size}-Duration={self.duration}s/ODER/{key}", + oder[key], + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) return None From 9b8e509d9cf111b731af28ec124dd902ba1fba77 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Mon, 28 Oct 2024 15:24:12 +0100 Subject: [PATCH 80/83] handle overlaped segmentation chunks corner case --- .../speaker_diarization_and_embedding.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 578da632f..e1837b926 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -1177,15 +1177,18 @@ def reconstruct( def aggregate(self, segmentations: SlidingWindowFeature, pad_duration:float) -> SlidingWindowFeature: num_chunks, num_frames, num_speakers = segmentations.data.shape - sliding_window = segmentations.sliding_window - frame_duration = sliding_window.duration / num_frames + frame_duration = segmentations.sliding_window.duration / num_frames - if num_chunks <= 1: - return segmentations[0] + window = SlidingWindow(step=frame_duration, duration=frame_duration) - num_padding_frames = np.round( - pad_duration / frame_duration - ).astype(np.uint32) + if num_chunks == 1: + return SlidingWindowFeature(segmentations[0], window) + + # if segmentation chunks are overlaped + if pad_duration < 0.: + return Inference.aggregate(segmentations, window) + + num_padding_frames = np.round(pad_duration / frame_duration).astype(np.uint32) aggregated_segmentation = segmentations[0] for chunk_segmentation in segmentations[1:]: @@ -1193,7 +1196,8 @@ def aggregate(self, segmentations: SlidingWindowFeature, pad_duration:float) -> aggregated_segmentation = np.concatenate( (aggregated_segmentation, padding, chunk_segmentation) ) - return SlidingWindowFeature(aggregated_segmentation.astype(np.int8), SlidingWindow(step=frame_duration, duration=frame_duration)) + + return SlidingWindowFeature(aggregated_segmentation.astype(np.int8), window) def to_diarization( self, From 7708935d034a3b06f7fcaa809ca96b28799f392b Mon Sep 17 00:00:00 2001 From: clement-pages Date: Mon, 28 Oct 2024 16:11:14 +0100 Subject: [PATCH 81/83] add some comments --- .../speaker_diarization_and_embedding.py | 144 ++++++++++++------ 1 file changed, 94 insertions(+), 50 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index e1837b926..6fd38c750 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -46,24 +46,17 @@ from scipy.spatial.distance import cdist +from pyannote.audio import Inference +from pyannote.audio.core.io import Audio from pyannote.audio.core.task import Problem, Resolution, Specifications, get_dtype from pyannote.audio.tasks import SpeakerDiarization -from pyannote.audio.torchmetrics import ( - DiarizationErrorRate, - FalseAlarmRate, - MissedDetectionRate, - SpeakerConfusionRate, -) from pyannote.audio.utils.loss import nll_loss from pyannote.audio.utils.permutation import permutate from pyannote.audio.utils.random import create_rng_for_worker from pyannote.audio.pipelines.clustering import KMeansClustering, OracleClustering from pyannote.audio.pipelines.utils import SpeakerDiarizationMixin -from pyannote.audio.core.io import Audio -from pyannote.metrics.diarization import ( - DiarizationErrorRate as GlobalDiarizationErrorRate, -) +from pyannote.metrics.diarization import DiarizationErrorRate Subtask = Literal["diarization", "embedding"] @@ -781,10 +774,24 @@ def train__iter__(self): yield next(chunks) def val__getitem__(self, idx) -> Dict: - """Validation items are generated so that all samples in a batch come from the same - validation file. These samples are created by sliding a window over the first seconds of - the validation file, with a step (for now arbitrally) set to 0.2 (20% of the task duration, - e.g. 1 second for a duration of 5 seconds)""" + """Validation items are generated so that all the chunks in a batch come from the same + validation file. These chunks are extracted regularly over all the file, so that the first + chunk start at the very beginning of the file, and the last chunk ends at the end of the file. + Step between chunks depends of the file duration and the total batch duration. This step can + be negative. In that case, chunks are overlapped. + + Parameters + ---------- + idx: int + item index. Note: this method may be incompatible with the use of sampler, + as this method requires incremental idx starting from 0. + + Returns + ------- + chunk: dict + extracted chunk from current validation file. The first chunk contains annotation + for the whole file. + """ file_idx = idx // self.batch_size chunk_idx = idx % self.batch_size @@ -811,6 +818,7 @@ def val__getitem__(self, idx) -> Dict: return chunk def val__len__(self): + """Return total length of validation, which is num_validation_files * batch_size""" return len(self.prepared_data["validation-files"]) * self.batch_size def collate_y(self, batch) -> torch.Tensor: @@ -892,7 +900,8 @@ def collate_fn(self, batch, stage="train"): Returns ------- batch : dict - Collated batch as {"X": torch.Tensor, "y": torch.Tensor} dict. + Collated batch as {"X": torch.Tensor, "y": torch.Tensor} dict (train). + Collated batch as {"X": torch.Tensor, "annotation": Annotation, "start_times": list} dict (validation) """ # collate X @@ -1139,10 +1148,8 @@ def reconstruct( ---------- segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature Raw speaker segmentation. - hard_clusters : (num_chunks, num_speakers) array + clusters : (num_chunks, num_speakers) array Output of clustering step. - count : (total_num_frames, 1) SlidingWindowFeature - Instantaneous number of active speakers. Returns ------- @@ -1173,9 +1180,30 @@ def reconstruct( clustered_segmentations = SlidingWindowFeature( clustered_segmentations, segmentations.sliding_window ) + return clustered_segmentations - def aggregate(self, segmentations: SlidingWindowFeature, pad_duration:float) -> SlidingWindowFeature: + def aggregate( + self, segmentations: SlidingWindowFeature, pad_duration: float + ) -> SlidingWindowFeature: + """"Aggregate segmentation chunks over time with padding between + each chunk. + + Parameters + ---------- + segmentations: SlidingWindowFeature + unaggragated segmentation chunks. Shape is (num_chunks, num_frames, num_speakers). + pad_duration: float + padding duration between two consecutive segmentation chunks. + In case pad_duration is < 0. (overlapped segmentation chunks), + `Inference.aggregate()` is used. + + Returns + ------- + segmentation: SlidingWindowFeature + aggregated segmentation. Shape is (num_frames', num_speakers). + """ + num_chunks, num_frames, num_speakers = segmentations.data.shape frame_duration = segmentations.sliding_window.duration / num_frames @@ -1202,21 +1230,22 @@ def aggregate(self, segmentations: SlidingWindowFeature, pad_duration:float) -> def to_diarization( self, segmentations: SlidingWindowFeature, - pad_duration: float = 0., + pad_duration: float = 0.0, ) -> SlidingWindowFeature: - """Build diarization out of preprocessed segmentation and precomputed speaker count + """Build diarization out of preprocessed segmentation Parameters ---------- segmentations : SlidingWindowFeature (num_chunks, num_frames, num_speakers)-shaped segmentations - count : SlidingWindow_feature - (num_frames, 1)-shaped speaker count + pad_duration: float, + padding duration between two consecutive segmentation chunks. + Can be negative (overlapped chunks). Returns ------- discrete_diarization : SlidingWindowFeature - Discrete (0s and 1s) diarization. + Discrete (0s and 1s) diarization. Shape is (num_frames', num_speakers') """ activations = self.aggregate(segmentations, pad_duration=pad_duration) @@ -1245,13 +1274,29 @@ def to_diarization( return SlidingWindowFeature(binary, activations.sliding_window) - def compute_metric( + def compute_der( self, reference: Annotation, hypothesis: Tuple[SlidingWindowFeature, np.ndarray], pad_duration: float, ): - """Compute diarization annotation from binarized segmentation and cluster (num_chunk, num_speaker)""" + """ Compute global Diarization Error Rate (DER) given reference and hypothesis. + DER is computed on part of the validation file that were used to build validation + batch. + + Parameters + ---------- + reference: pyannote.core.Annotation + cropped file's annotation matching part of the file used to build validation batch + hypothesis: (pyannote.core.SlidingWindowFeature, np.ndarray) + tuple containing unclustered segmentation chunks and clusters from the clustering step + pad_duration: float + padding duration between two consecutives chunks. Can be negative (overlapped chunks) + + Returns: dict + Dict containing computed DER and its components (false alarm, missed detection + and confusion) + """ frames = self.model.receptive_field binarized_segmentations, clusters = hypothesis @@ -1260,27 +1305,27 @@ def compute_metric( # shape: (num_chunks, num_speakers) clusters[inactive_speakers] = -2 - clustered_segmentations = self.reconstruct( - binarized_segmentations, clusters - ) + clustered_segmentations = self.reconstruct(binarized_segmentations, clusters) - binarized_diarization = self.to_diarization(clustered_segmentations, pad_duration=pad_duration) + binarized_diarization = self.to_diarization( + clustered_segmentations, pad_duration=pad_duration + ) diarization = SpeakerDiarizationMixin.to_annotation(binarized_diarization) - metric = GlobalDiarizationErrorRate() + metric = DiarizationErrorRate() metric(reference, diarization, detailed=True) result = metric[:] - metric_dict = {"der": 0.} + metric_dict = {"der": 0.0} for component in ["false alarm", "missed detection", "confusion"]: - metric_dict[component] = (result[component] / result["total"]) + metric_dict[component] = result[component] / result["total"] metric_dict["der"] += metric_dict[component] return metric_dict # TODO: no need to compute gradient in this method def validation_step(self, batch, batch_idx: int): - """Compute validation loss and metric + """Compute (global) Diarization Error Rate Parameters ---------- @@ -1312,12 +1357,13 @@ def validation_step(self, batch, batch_idx: int): # keep reference only on chunk segments: reference = reference.crop(support) - # corner case where no reference segments intersects the timeline + # corner case where no reference segments intersects the timeline. + # This case can occurs if batch duration is too short. if len(reference) == 0: return None waveform = batch["X"] - #shape: (num_chunks, num_channels, local_num_samples) + # shape: (num_chunks, num_channels, local_num_samples) # segmentation + embeddings extraction step segmentations, embeddings = self.model(waveform) @@ -1332,39 +1378,42 @@ def validation_step(self, batch, batch_idx: int): start=batch["start_times"][0], duration=self.duration, step=step ) + # convert powert segmentations to multilabel segmentation binarized_segmentations = self.model.powerset.to_multilabel(segmentations) + # gradient is uneeded here, so we can safely detach tensors from the gradient graph binarized_segmentations = binarized_segmentations.cpu().detach().numpy() + embeddings = embeddings.cpu().detach().numpy() + binarized_segmentations = SlidingWindowFeature( binarized_segmentations, sliding_window ) - embeddings = embeddings.cpu().detach().numpy() - # clustering step clustering = KMeansClustering() - hard_clusters, _, _ = clustering( + clusters, _, _ = clustering( embeddings=embeddings, segmentations=binarized_segmentations, num_clusters=num_speakers, ) oracle_clustering = OracleClustering() - oracle_hard_clusters, _, _ = oracle_clustering( + oracle_clusters, _, _ = oracle_clustering( segmentations=binarized_segmentations, file=file, frames=self.model.receptive_field.step, ) + # compute diarization error rate pad_duration = step - self.duration - der = self.compute_metric( + der = self.compute_der( reference=reference, - hypothesis=(binarized_segmentations, hard_clusters), + hypothesis=(binarized_segmentations, clusters), pad_duration=pad_duration, ) - oder = self.compute_metric( + oder = self.compute_der( reference=reference, - hypothesis=(binarized_segmentations, oracle_hard_clusters), + hypothesis=(binarized_segmentations, oracle_clusters), pad_duration=pad_duration, ) @@ -1395,9 +1444,4 @@ def default_metric( """Returns diarization error rate and its components for diarization subtask, and equal error rate for the embedding part """ - return { - "DiarizationErrorRate": DiarizationErrorRate(0.5), - "DiarizationErrorRate/Confusion": SpeakerConfusionRate(0.5), - "DiarizationErrorRate/Miss": MissedDetectionRate(0.5), - "DiarizationErrorRate/FalseAlarm": FalseAlarmRate(0.5), - } + return {} From 2edebc4cb2f9e1e8e97bdd2c16c413e2d0858335 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Mon, 18 Nov 2024 14:37:21 +0100 Subject: [PATCH 82/83] replace `pyannote.metrics` DER by `pyannote.audio.torchmetrics` one This is done to use the same metrics as for other pyannote's tasks, and to benefit from lightning advantages (parallelization...) --- .../speaker_diarization_and_embedding.py | 370 +++++++++--------- 1 file changed, 185 insertions(+), 185 deletions(-) diff --git a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py index 6fd38c750..c0cb49aac 100644 --- a/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py +++ b/pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py @@ -31,7 +31,6 @@ import numpy as np import torch from einops import rearrange -from matplotlib import pyplot as plt from pyannote.core import ( Annotation, Segment, @@ -56,7 +55,14 @@ from pyannote.audio.pipelines.clustering import KMeansClustering, OracleClustering from pyannote.audio.pipelines.utils import SpeakerDiarizationMixin -from pyannote.metrics.diarization import DiarizationErrorRate +from pyannote.audio.torchmetrics import ( + DiarizationErrorRate, + FalseAlarmRate, + MissedDetectionRate, + SpeakerConfusionRate, +) + +from torchmetrics import MetricCollection Subtask = Literal["diarization", "embedding"] @@ -126,6 +132,17 @@ def __init__( # * diarization databases are those with file or database speaker label scope self.embedding_files_id = [] + self.validation_metrics = MetricCollection( + { + "DiarizationErrorRate": DiarizationErrorRate(0.5), + "DiarizationErrorRate/FalseAlarm": FalseAlarmRate(0.5), + "DiarizationErrorRate/Miss": MissedDetectionRate(0.5), + "DiarizationErrorRate/Confusion": SpeakerConfusionRate(0.5), + } + ) + + self.oracle_validation_metrics = self.validation_metrics.clone(prefix="Oracle") + def prepare_data(self): """Use this to prepare data from task protocol @@ -785,7 +802,7 @@ def val__getitem__(self, idx) -> Dict: idx: int item index. Note: this method may be incompatible with the use of sampler, as this method requires incremental idx starting from 0. - + Returns ------- chunk: dict @@ -1159,9 +1176,7 @@ def reconstruct( num_chunks, num_frames, _ = segmentations.data.shape num_clusters = np.max(clusters) + 1 - clustered_segmentations = np.nan * np.zeros( - (num_chunks, num_frames, num_clusters) - ) + clustered_segmentations = np.zeros((num_chunks, num_frames, num_clusters)) for c, (cluster, (chunk, segmentation)) in enumerate( zip(clusters, segmentations) @@ -1183,122 +1198,120 @@ def reconstruct( return clustered_segmentations - def aggregate( - self, segmentations: SlidingWindowFeature, pad_duration: float - ) -> SlidingWindowFeature: - """"Aggregate segmentation chunks over time with padding between - each chunk. - - Parameters - ---------- - segmentations: SlidingWindowFeature - unaggragated segmentation chunks. Shape is (num_chunks, num_frames, num_speakers). - pad_duration: float - padding duration between two consecutive segmentation chunks. - In case pad_duration is < 0. (overlapped segmentation chunks), - `Inference.aggregate()` is used. - - Returns - ------- - segmentation: SlidingWindowFeature - aggregated segmentation. Shape is (num_frames', num_speakers). - """ - - num_chunks, num_frames, num_speakers = segmentations.data.shape - frame_duration = segmentations.sliding_window.duration / num_frames - - window = SlidingWindow(step=frame_duration, duration=frame_duration) - - if num_chunks == 1: - return SlidingWindowFeature(segmentations[0], window) - - # if segmentation chunks are overlaped - if pad_duration < 0.: - return Inference.aggregate(segmentations, window) - - num_padding_frames = np.round(pad_duration / frame_duration).astype(np.uint32) - aggregated_segmentation = segmentations[0] - - for chunk_segmentation in segmentations[1:]: - padding = np.zeros((num_padding_frames, num_speakers)) - aggregated_segmentation = np.concatenate( - (aggregated_segmentation, padding, chunk_segmentation) - ) - - return SlidingWindowFeature(aggregated_segmentation.astype(np.int8), window) - - def to_diarization( + # def aggregate( + # self, segmentations: SlidingWindowFeature, pad_duration: float + # ) -> SlidingWindowFeature: + # """ "Aggregate segmentation chunks over time with padding between + # each chunk. + + # Parameters + # ---------- + # segmentations: SlidingWindowFeature + # unaggragated segmentation chunks. Shape is (num_chunks, num_frames, num_speakers). + # pad_duration: float + # padding duration between two consecutive segmentation chunks. + # In case pad_duration is < 0. (overlapped segmentation chunks), + # `Inference.aggregate()` is used. + + # Returns + # ------- + # segmentation: SlidingWindowFeature + # aggregated segmentation. Shape is (num_frames', num_speakers). + # """ + + # num_chunks, num_frames, num_speakers = segmentations.data.shape + # frame_duration = segmentations.sliding_window.duration / num_frames + + # window = SlidingWindow(step=frame_duration, duration=frame_duration) + + # if num_chunks == 1: + # return SlidingWindowFeature(segmentations[0], window) + + # # if segmentation chunks are overlaped + # if pad_duration < 0.0: + # return Inference.aggregate(segmentations, window) + + # num_padding_frames = np.round(pad_duration / frame_duration).astype(np.uint32) + # aggregated_segmentation = segmentations[0] + + # for chunk_segmentation in segmentations[1:]: + # padding = np.zeros((num_padding_frames, num_speakers)) + # aggregated_segmentation = np.concatenate( + # (aggregated_segmentation, padding, chunk_segmentation) + # ) + + # return SlidingWindowFeature(aggregated_segmentation.astype(np.int8), window) + + # def to_diarization( + # self, + # segmentations: SlidingWindowFeature, + # pad_duration: float = 0.0, + # ) -> SlidingWindowFeature: + # """Build diarization out of preprocessed segmentation + + # Parameters + # ---------- + # segmentations : SlidingWindowFeature + # (num_chunks, num_frames, num_speakers)-shaped segmentations + # pad_duration: float, + # padding duration between two consecutive segmentation chunks. + # Can be negative (overlapped chunks). + + # Returns + # ------- + # discrete_diarization : SlidingWindowFeature + # Discrete (0s and 1s) diarization. Shape is (num_frames', num_speakers') + # """ + + # activations = self.aggregate(segmentations, pad_duration=pad_duration) + # # shape: (num_frames, num_speakers) + # _, num_speakers = activations.data.shape + + # count = np.sum(activations, axis=1, keepdims=True) + # # shape: (num_frames, 1) + + # max_speakers_per_frame = np.max(count.data) + # if num_speakers < max_speakers_per_frame: + # activations.data = np.pad( + # activations.data, ((0, 0), (0, max_speakers_per_frame - num_speakers)) + # ) + + # extent = activations.extent & count.extent + # activations = activations.crop(extent, return_data=False) + # count = count.crop(extent, return_data=False) + + # sorted_speakers = np.argsort(-activations, axis=-1) + # binary = np.zeros_like(activations.data) + + # for t, ((_, c), speakers) in enumerate(zip(count, sorted_speakers)): + # for i in range(c.item()): + # binary[t, speakers[i]] = 1.0 + + # return SlidingWindowFeature(binary, activations.sliding_window) + + def compute_metrics( self, - segmentations: SlidingWindowFeature, - pad_duration: float = 0.0, - ) -> SlidingWindowFeature: - """Build diarization out of preprocessed segmentation - - Parameters - ---------- - segmentations : SlidingWindowFeature - (num_chunks, num_frames, num_speakers)-shaped segmentations - pad_duration: float, - padding duration between two consecutive segmentation chunks. - Can be negative (overlapped chunks). - - Returns - ------- - discrete_diarization : SlidingWindowFeature - Discrete (0s and 1s) diarization. Shape is (num_frames', num_speakers') - """ - - activations = self.aggregate(segmentations, pad_duration=pad_duration) - # shape: (num_frames, num_speakers) - _, num_speakers = activations.data.shape - - count = np.sum(activations, axis=1, keepdims=True) - # shape: (num_frames, 1) - - max_speakers_per_frame = np.max(count.data) - if num_speakers < max_speakers_per_frame: - activations.data = np.pad( - activations.data, ((0, 0), (0, max_speakers_per_frame - num_speakers)) - ) - - extent = activations.extent & count.extent - activations = activations.crop(extent, return_data=False) - count = count.crop(extent, return_data=False) - - sorted_speakers = np.argsort(-activations, axis=-1) - binary = np.zeros_like(activations.data) - - for t, ((_, c), speakers) in enumerate(zip(count, sorted_speakers)): - for i in range(c.item()): - binary[t, speakers[i]] = 1.0 - - return SlidingWindowFeature(binary, activations.sliding_window) - - def compute_der( - self, - reference: Annotation, - hypothesis: Tuple[SlidingWindowFeature, np.ndarray], - pad_duration: float, - ): - """ Compute global Diarization Error Rate (DER) given reference and hypothesis. - DER is computed on part of the validation file that were used to build validation - batch. + discretized_reference, + prediction: Tuple[SlidingWindowFeature, np.ndarray], + oracle_mode: bool, + ) -> None: + """Compute (oracle) Diarization Error Rate at file level + given the reference and hypothesis. + DER is only computed on parts of the validation file + that were used to build validation batch. Parameters ---------- - reference: pyannote.core.Annotation - cropped file's annotation matching part of the file used to build validation batch - hypothesis: (pyannote.core.SlidingWindowFeature, np.ndarray) - tuple containing unclustered segmentation chunks and clusters from the clustering step - pad_duration: float - padding duration between two consecutives chunks. Can be negative (overlapped chunks) - - Returns: dict - Dict containing computed DER and its components (false alarm, missed detection - and confusion) + discretized_reference: np.ndarray + cropped file's discretized reference matching parts + of the validation file used to build current batch + prediction: (pyannote.core.SlidingWindowFeature, np.ndarray) + tuple containing unclustered segmentation chunks + and clusters from the clustering step + oracle_mode: boolean + Whether to compute DER or oracle DER """ - frames = self.model.receptive_field - binarized_segmentations, clusters = hypothesis + binarized_segmentations, clusters = prediction # keep track of inactive speakers inactive_speakers = np.sum(binarized_segmentations.data, axis=1) == 0 @@ -1306,26 +1319,31 @@ def compute_der( clusters[inactive_speakers] = -2 clustered_segmentations = self.reconstruct(binarized_segmentations, clusters) + # shape: (num_chunks, num_frames, num_speakers) - binarized_diarization = self.to_diarization( - clustered_segmentations, pad_duration=pad_duration - ) - diarization = SpeakerDiarizationMixin.to_annotation(binarized_diarization) - - metric = DiarizationErrorRate() - metric(reference, diarization, detailed=True) + clustered_segmentations = torch.from_numpy(clustered_segmentations.data) + hypothesis = rearrange(clustered_segmentations, "c f s -> s (c f)") + # shape: (num_speakers, num_chunks * num_frames) - result = metric[:] - metric_dict = {"der": 0.0} - for component in ["false alarm", "missed detection", "confusion"]: - metric_dict[component] = result[component] / result["total"] - metric_dict["der"] += metric_dict[component] + reference = torch.from_numpy(discretized_reference.T) + # shape: (num_speakers, num_chunks * num_frames) - return metric_dict + # calculate and log metrics + name = "oracle_validation_metrics" if oracle_mode else "validation_metrics" + metrics = getattr(self, name) + outputs = metrics(hypothesis.unsqueeze(0), reference.unsqueeze(0)) + self.model.log_dict( + outputs, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) - # TODO: no need to compute gradient in this method def validation_step(self, batch, batch_idx: int): - """Compute (global) Diarization Error Rate + """Validation step consists of applying a diarization pipeline + on a validation file to compute file-level Diarization Error Rate (DER) + and Oracle Diarization Error Rate (ODER). Parameters ---------- @@ -1340,27 +1358,31 @@ def validation_step(self, batch, batch_idx: int): reference = batch["annotation"] num_speakers = len(reference.labels()) - frames = self.model.receptive_field - start_times = batch["start_times"] file_id = batch["meta"]["file"][0] file = self.get_file(file_id) + # needed by oracle clustering file["annotation"] = reference assert reference.uri in file["audio"] - # build support timeline from chunk segments - support = Timeline() - for start_time in start_times: - support.add(Segment(start_time, start_time + self.duration)) + resolution = self.model.receptive_field - # keep reference only on chunk segments: - reference = reference.crop(support) - # corner case where no reference segments intersects the timeline. - # This case can occurs if batch duration is too short. - if len(reference) == 0: - return None + # get discretized reference for current file + discretized_segments = [] + num_frames = int( + self.model.num_frames(self.model.hparams["sample_rate"] * self.duration) + ) + for start_time in start_times: + discretized_segment = reference.discretize( + support=Segment(start_time, start_time + self.duration), + resolution=resolution, + labels=reference.labels(), + ) + discretized_segments.append(discretized_segment.data[:num_frames]) + discretized_reference = np.concatenate(discretized_segments) + # shape: (num_chunks * num_frames, num_speakers) waveform = batch["X"] # shape: (num_chunks, num_channels, local_num_samples) @@ -1374,68 +1396,46 @@ def validation_step(self, batch, batch_idx: int): else: step = self.duration - sliding_window = SlidingWindow( - start=batch["start_times"][0], duration=self.duration, step=step - ) - - # convert powert segmentations to multilabel segmentation + # convert from powerset segmentations to multilabel segmentations binarized_segmentations = self.model.powerset.to_multilabel(segmentations) - # gradient is uneeded here, so we can safely detach tensors from the gradient graph + # gradient is uneeded here, so we can detach tensors from the gradient graph binarized_segmentations = binarized_segmentations.cpu().detach().numpy() embeddings = embeddings.cpu().detach().numpy() binarized_segmentations = SlidingWindowFeature( - binarized_segmentations, sliding_window + binarized_segmentations, + SlidingWindow( + start=batch["start_times"][0], duration=self.duration, step=step + ), ) - # clustering step + # compute file-wise diarization error rate clustering = KMeansClustering() clusters, _, _ = clustering( embeddings=embeddings, segmentations=binarized_segmentations, num_clusters=num_speakers, ) + der = self.compute_metrics( + discretized_reference=discretized_reference, + prediction=(binarized_segmentations, clusters), + oracle_mode=False, + ) + + # compute file-wise oracle diarization error rate oracle_clustering = OracleClustering() oracle_clusters, _, _ = oracle_clustering( segmentations=binarized_segmentations, file=file, - frames=self.model.receptive_field.step, - ) - - # compute diarization error rate - pad_duration = step - self.duration - der = self.compute_der( - reference=reference, - hypothesis=(binarized_segmentations, clusters), - pad_duration=pad_duration, + frames=resolution.step, ) - - oder = self.compute_der( - reference=reference, - hypothesis=(binarized_segmentations, oracle_clusters), - pad_duration=pad_duration, + oder = self.compute_metrics( + discretized_reference=discretized_reference, + prediction=(binarized_segmentations, oracle_clusters), + oracle_mode=True, ) - for key in der: - self.model.log( - f"BS={self.batch_size}-Duration={self.duration}s/DER/{key}", - der[key], - on_step=False, - on_epoch=True, - prog_bar=True, - logger=True, - ) - - self.model.log( - f"BS={self.batch_size}-Duration={self.duration}s/ODER/{key}", - oder[key], - on_step=False, - on_epoch=True, - prog_bar=True, - logger=True, - ) - return None def default_metric( From 76d4ec9755d8749ea44f7a0d8313707b4dc62836 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Mon, 18 Nov 2024 14:46:18 +0100 Subject: [PATCH 83/83] update joint pipeline --- pyannote/audio/pipelines/speaker_diarization.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyannote/audio/pipelines/speaker_diarization.py b/pyannote/audio/pipelines/speaker_diarization.py index 7e89a70fc..e051e3741 100644 --- a/pyannote/audio/pipelines/speaker_diarization.py +++ b/pyannote/audio/pipelines/speaker_diarization.py @@ -27,7 +27,7 @@ import math import textwrap import warnings -from typing import Callable, Mapping, Optional, Text, Union +from typing import Callable, Mapping, Optional, Text, Union, Tuple import numpy as np import torch @@ -714,6 +714,7 @@ def __init__( clustering: str = "AgglomerativeClustering", batch_size: int = 1, use_auth_token: Union[Text, None] = None, + der_variant: Optional[dict] = None, ): super().__init__() @@ -731,6 +732,8 @@ def __init__( self.step = step self.klustering = clustering + self.der_variant = der_variant or {"collar": 0.0, "skip_overlap": False} + duration: float = segmentation_specifications.duration self._inference = Inference( model,