add support to voxblink2 pretrained model and update readme
王帅 committed on Sep 3, 2024
1 parent 4be9d57 · commit 6b149b3
Showing 11 changed files with 325 additions and 4 deletions.
@@ -0,0 +1,92 @@
# Research that uses the WeSpeaker project

[TOC]

## Introduction

After the release of the WeSpeaker project, many users from both academia and industry have actively engaged with it in their research. We appreciate all the feedback and contributions from the community and would like to highlight some of these interesting works below.

## New Architecture

### ReDimNet

- Implementation in WeSpeaker
- Paper Link

```
@article{yakovlev2024reshape,
  title={Reshape Dimensions Network for Speaker Recognition},
  author={Yakovlev, Ivan and Makarov, Rostislav and Balykin, Andrei and Malov, Pavel and Okhotnikov, Anton and Torgashov, Nikita},
  journal={arXiv preprint arXiv:2407.18223},
  year={2024}
}
```

### Golden Gemini DF-ResNet

```
@article{liu2024golden,
  title={Golden gemini is all you need: Finding the sweet spots for speaker verification},
  author={Liu, Tianchi and Lee, Kong Aik and Wang, Qiongqiong and Li, Haizhou},
  journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing},
  year={2024},
  publisher={IEEE}
}
```

### SAM-ResNet

```
@inproceedings{qin2022simple,
  title={Simple attention module based speaker verification with iterative noisy label detection},
  author={Qin, Xiaoyi and Li, Na and Weng, Chao and Su, Dan and Li, Ming},
  booktitle={ICASSP 2022-2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  pages={6722--6726},
  year={2022},
  organization={IEEE}
}
```

### Whisper-based Speaker Verification

## Pipelines

### DINO Pretraining with Large-scale Data

```
@inproceedings{wang2024leveraging,
  title={Leveraging In-the-Wild Data for Effective Self-Supervised Pretraining in Speaker Recognition},
  author={Wang, Shuai and Bai, Qibing and Liu, Qi and Yu, Jianwei and Chen, Zhengyang and Han, Bing and Qian, Yanmin and Li, Haizhou},
  booktitle={ICASSP 2024-2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  pages={10901--10905},
  year={2024},
  organization={IEEE}
}
```

## New Recipes / Datasets

### VoxBlink

```
@inproceedings{lin2024voxblink,
  title={Voxblink: A large scale speaker verification dataset on camera},
  author={Lin, Yuke and Qin, Xiaoyi and Zhao, Guoqing and Cheng, Ming and Jiang, Ning and Wu, Haiying and Li, Ming},
  booktitle={ICASSP 2024-2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  pages={10271--10275},
  year={2024},
  organization={IEEE}
}

@article{lin2024voxblink2,
  title={VoxBlink2: A 100K+ Speaker Recognition Corpus and the Open-Set Speaker-Identification Benchmark},
  author={Lin, Yuke and Cheng, Ming and Zhang, Fulin and Gao, Yingying and Zhang, Shilei and Li, Ming},
  journal={arXiv preprint arXiv:2407.11510},
  year={2024}
}
```

@@ -7,4 +7,5 @@ Reference

./paper.md
./speaker_recognition_papers.md
./papers_using_wespeaker.md
./python_api/modules.rst
@@ -0,0 +1,163 @@
import torch
import torch.nn as nn
import wespeaker.models.pooling_layers as pooling_layers

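# Residual basic block augmented with SimAM, a parameter-free attention that
# rescales each activation by an energy-based weight (cf. the SAM-ResNet paper
# referenced in the updated docs).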
class SimAMBasicBlock(nn.Module):
    expansion = 1

    def __init__(
        self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1
    ):
        super(SimAMBasicBlock, self).__init__()
        self.conv1 = ConvLayer(
            in_planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = NormLayer(planes)
        self.conv2 = ConvLayer(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = NormLayer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

        self.downsample = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.downsample = nn.Sequential(
                ConvLayer(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                NormLayer(self.expansion * planes),
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.SimAM(out)
        out += self.downsample(x)
        out = self.relu(out)
        return out

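    # SimAM energy function: d is the squared deviation of each activation from
    # its channel-wise spatial mean and v the corresponding variance estimate;
    # sigmoid(d / (4 * (v + lambda_p)) + 0.5) is used as a per-element attention
    # weight on X.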
    def SimAM(self, X, lambda_p=1e-4):
        n = X.shape[2] * X.shape[3] - 1
        d = (X - X.mean(dim=[2, 3], keepdim=True)).pow(2)
        v = d.sum(dim=[2, 3], keepdim=True) / n
        E_inv = d / (4 * (v + lambda_p)) + 0.5
        return X * self.sigmoid(E_inv)


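# 2D ResNet trunk over (batch, channel=1, freq, time) feature maps; three
# stride-2 stages halve the spatial resolution, so an 80-dim Fbank input ends
# up with a 10-bin frequency axis and in_planes * 8 channels.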
class ResNet(nn.Module):
    def __init__(
        self, in_planes, block, num_blocks, in_ch=1, **kwargs
    ):
        super(ResNet, self).__init__()
        self.in_planes = in_planes
        self.NormLayer = nn.BatchNorm2d
        self.ConvLayer = nn.Conv2d

        self.conv1 = self.ConvLayer(
            in_ch, in_planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = self.NormLayer(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(
            block, in_planes, num_blocks[0], stride=1, block_id=1
        )
        self.layer2 = self._make_layer(
            block, in_planes * 2, num_blocks[1], stride=2, block_id=2
        )
        self.layer3 = self._make_layer(
            block, in_planes * 4, num_blocks[2], stride=2, block_id=3
        )
        self.layer4 = self._make_layer(
            block, in_planes * 8, num_blocks[3], stride=2, block_id=4
        )

    def _make_layer(self, block, planes, num_blocks, stride, block_id=1):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(
                block(
                    self.ConvLayer,
                    self.NormLayer,
                    self.in_planes,
                    planes,
                    stride,
                    block_id,
                )
            )
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x


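# Two depth presets with SimAM blocks: a ResNet34-style layout ([3, 4, 6, 3])
# and a deeper 100-layer variant ([6, 16, 24, 3]).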
def SimAM_ResNet34(in_planes):
    return ResNet(in_planes, SimAMBasicBlock, [3, 4, 6, 3])


def SimAM_ResNet100(in_planes):
    return ResNet(in_planes, SimAMBasicBlock, [6, 16, 24, 3])


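# Speaker embedding extractor: SimAM-ResNet front-end, ASP (attentive
# statistics) pooling over the time axis, then a linear bottleneck projecting
# to embed_dim. Input is (batch, num_frames, acoustic_dim) Fbank features.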
class SimAM_ResNet34_ASP(nn.Module):
    def __init__(self, in_planes=64, embed_dim=256, acoustic_dim=80, dropout=0):
        super(SimAM_ResNet34_ASP, self).__init__()
        self.front = SimAM_ResNet34(in_planes)
        self.pooling = pooling_layers.ASP(in_planes, acoustic_dim)
        self.bottleneck = nn.Linear(self.pooling.out_dim, embed_dim)
        self.drop = nn.Dropout(dropout) if dropout else None

    def forward(self, x):
        x = x.permute(0, 2, 1)
        x = self.front(x.unsqueeze(dim=1))
        x = self.pooling(x)
        if self.drop:
            x = self.drop(x)
        x = self.bottleneck(x)
        return x


class SimAM_ResNet100_ASP(nn.Module):
    def __init__(self, in_planes=64, embed_dim=256, acoustic_dim=80, dropout=0):
        super(SimAM_ResNet100_ASP, self).__init__()
        self.front = SimAM_ResNet100(in_planes)
        self.pooling = pooling_layers.ASP(in_planes, acoustic_dim)
        self.bottleneck = nn.Linear(self.pooling.out_dim, embed_dim)
        self.drop = nn.Dropout(dropout) if dropout else None

    def forward(self, x):
        x = x.permute(0, 2, 1)
        x = self.front(x.unsqueeze(dim=1))
        x = self.pooling(x)
        if self.drop:
            x = self.drop(x)
        x = self.bottleneck(x)
        return x


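# Quick smoke test: a dummy batch of 200 frames of 80-dim features; prints the
# embedding size and the number of parameters in millions.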
if __name__ == '__main__':
    x = torch.zeros(1, 200, 80)
    model = SimAM_ResNet34_ASP(embed_dim=256)
    model.eval()
    out = model(x)
    print(out[-1].size())

    num_params = sum(p.numel() for p in model.parameters())
    print("{} M".format(num_params / 1e6))