Skip to content

Commit

Permalink
add end-to-end diarization and embedding model
Browse files Browse the repository at this point in the history
  • Loading branch information
clement-pages committed Jun 20, 2023
1 parent dfdd8f3 commit 1888360
Showing 1 changed file with 225 additions and 0 deletions.
225 changes: 225 additions & 0 deletions pyannote/audio/models/joint/end_to_end_diarization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
# MIT License
#
# Copyright (c) 2020 CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Optional
from warnings import warn
from einops import rearrange

import torch
from torch import nn
import torch.nn.functional as F

from pyannote.audio.core.model import Model
from pyannote.audio.core.task import Task
from pyannote.audio.models.blocks.sincnet import SincNet
from pyannote.audio.models.blocks.pooling import StatsPool
from pyannote.audio.utils.params import merge_dict
from pyannote.core.utils.generators import pairwise

class SpeakerEndToEndDiarization(Model):
    """Speaker End-to-End Diarization and Embedding model.

    A SincNet front-end followed by a stack of TDNN blocks is shared by two
    branches that split after ``separation_idx`` TDNN blocks::

        SINCNET -- TDNN .. TDNN -- TDNN ..TDNN -- StatsPool -- Linear -- Classifier
                           \\ LSTM ... LSTM -- FeedForward -- Classifier

    The upper path produces a speaker embedding, the lower path produces
    frame-level diarization scores.
    """

    # Default SincNet configuration (merged with user-provided dict in __init__).
    SINCNET_DEFAULTS = {"stride": 10}
    # Default LSTM configuration; "monolithic" selects one multi-layer nn.LSTM
    # (True) vs. a ModuleList of single-layer LSTMs (False).
    LSTM_DEFAULTS = {
        "hidden_size": 128,
        "num_layers": 2,
        "bidirectional": True,
        "monolithic": True,
        "dropout": 0.0,
        "batch_first": True,
    }
    # Default feed-forward configuration for the diarization branch.
    LINEAR_DEFAULTS = {"hidden_size": 128, "num_layers": 2}

def __init__(
self,
sincnet: dict = None,
lstm: dict= None,
linear: dict = None,
sample_rate: int = 16000,
num_channels: int = 1,
num_features: int = 60,
embedding_dim: int = 512,
separation_idx: int = 2,
task: Optional[Task] = None,
):
super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task)

if num_features != 60:
warn("For now, the model only support a number of features of 60. Set it to 60")
num_features = 60
self.num_features = num_features
self.separation_idx = separation_idx
self.save_hyperparameters("num_features", "embedding_dim", "separation_idx")


# sincnet module
sincnet = merge_dict(self.SINCNET_DEFAULTS, sincnet)
sincnet["sample_rate"] = sample_rate
self.sincnet =SincNet(**sincnet)
self.save_hyperparameters("sincnet")

# tdnn modules
self.tdnn_blocks = nn.ModuleList()
in_channel = num_features
out_channels = [512, 512, 512, 512, 1500]
kernel_sizes = [5, 3, 3, 1, 1]
dilations = [1, 2, 3, 1, 1]

for out_channel, kernel_size, dilation in zip(
out_channels, kernel_sizes, dilations
):
self.tdnn_blocks.extend(
[
nn.Sequential(
nn.Conv1d(
in_channels=in_channel,
out_channels=out_channel,
kernel_size=kernel_size,
dilation=dilation,
),
nn.LeakyReLU(),
nn.BatchNorm1d(out_channel),
),
]
)
in_channel = out_channel

# lstm modules:
lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
self.save_hyperparameters("lstm")
monolithic = lstm["monolithic"]
if monolithic:
multi_layer_lstm = dict(lstm)
del multi_layer_lstm["monolithic"]
self.lstm = nn.LSTM(out_channels[separation_idx], **multi_layer_lstm)
else:
num_layers = lstm["num_layers"]
if num_layers > 1:
self.dropout = nn.Dropout(p=lstm["dropout"])

one_layer_lstm = dict(lstm)
del one_layer_lstm["monolithic"]
one_layer_lstm["num_layers"] = 1
one_layer_lstm["dropout"] = 0.0

self.lstm = nn.ModuleList(
[
nn.LSTM(
out_channels[separation_idx]
if i == 0
else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1),
**one_layer_lstm
)
for i in range(num_layers)
]
)

# linear module for the diarization part:
linear = merge_dict(self.LINEAR_DEFAULTS, linear)
self.save_hyperparameters("linear")
if linear["num_layers"] < 1:
return

lstm_out_features: int = self.hparams.lstm["hidden_size"] * (
2 if self.hparams.lstm["bidirectional"] else 1
)
self.linear = nn.ModuleList(
[
nn.Linear(in_features, out_features)
for in_features, out_features in pairwise(
[
lstm_out_features,
]
+ [self.hparams.linear["hidden_size"]]
* self.hparams.linear["num_layers"]
)
]
)

# stats pooling module for the embedding part:
self.stats_pool = StatsPool()
# linear module for the embedding part:
self.embedding = nn.Linear(in_channel * 2, embedding_dim)



def build(self):
if self.hparams.linear["num_layers"] > 0:
in_features = self.hparams.linear["hidden_size"]
else:
in_features = self.hparams.lstm["hidden_size"] * (
2 if self.hparams.lstm["bidirectional"] else 1
)

out_features = self.specifications.num_powerset_classes

self.classifier = nn.Linear(in_features, out_features)
self.activation = self.default_activation()


def forward(self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Parameters
----------
waveforms : torch.Tensor
Batch of waveforms with shape (batch, channel, sample)
weights : torch.Tensor, optional
Batch of weights wiht shape (batch, frame)
"""
common_outputs = self.sincnet(waveforms)
# (batch, features, frames)
# common part to diarization and embedding:
tdnn_idx = 0
while tdnn_idx <= self.separation_idx:
common_outputs = self.tdnn_blocks[tdnn_idx](common_outputs)
tdnn_idx = tdnn_idx + 1
# diarization part:
if self.hparams.lstm["monolithic"]:
diarization_outputs, _ = self.lstm(
rearrange(common_outputs, "batch feature frame -> batch frame feature")
)
else:
diarization_outputs = rearrange(common_outputs, "batch feature frame -> batch frame feature")
for i, lstm in enumerate(self.lstm):
diarization_outputs, _ = lstm(diarization_outputs)
if i + 1 < self.hparams.lstm["num_layers"]:
diarization_outputs = self.linear()

if self.hparams.linear["num_layers"] > 0:
for linear in self.linear:
diarization_outputs = F.leaky_relu(linear(diarization_outputs))
diarization_outputs = self.classifier(diarization_outputs)
diarization_outputs = self.activation(diarization_outputs)

# embedding part:
embedding_outputs = torch.clone(common_outputs)

This comment has been minimized.

Copy link
@hbredin

hbredin Jun 20, 2023

Member

Don't do that, this might break gradient flow.
Simply use common_outputs directly.

This comment has been minimized.

Copy link
@clement-pages

clement-pages Jun 20, 2023

Author Collaborator

According to pytorch doc, the clone function allows gradient flow, from result to input.I used clone to ensure that modifications applied on embedding_outputs don't affect common_outputs. But here, I can use common_outputs directly, as this tensor is not used anymore in the method.

for tdnn_block in self.tdnn_blocks[tdnn_idx:]:
embedding_outputs = tdnn_block(embedding_outputs)

# TODO : reinject diarization outputs into the pooling layers:
embedding_outputs = self.stats_pool(embedding_outputs)

This comment has been minimized.

Copy link
@hbredin

hbredin Jun 20, 2023

Member

It is already implemented in the other PR, isn't it?

All you have to do is convert from powerset diarization_outputs to multilabel encoding. Probably not as straightforward as it sounds though...

You should use pyannote.audio.utils.powerset.Powerset module for that.
In build,

self.powerset = Powerset(...)

In forward:

weights = self.powerset(diarization_outputs).reshape(...)
embedding_outputs = self.stats_pool(embedding_output, weights=weights)

This comment has been minimized.

Copy link
@clement-pages

clement-pages Jun 20, 2023

Author Collaborator

Thank you for this piece of code !

embedding_outputs = self.embedding(embedding_outputs)

return (diarization_outputs, embedding_outputs)

0 comments on commit 1888360

Please sign in to comment.