diff --git a/wespeaker/models/campplus.py b/wespeaker/models/campplus.py index 5d6c7cfb..8effcadc 100644 --- a/wespeaker/models/campplus.py +++ b/wespeaker/models/campplus.py @@ -1,4 +1,5 @@ # Copyright (c) 2023 Hongji Wang (jijijiang77@gmail.com) +# 2024 Zhengyang Chen (chenzhengyang117@gmail.com) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -394,6 +395,17 @@ def __init__(self, if m.bias is not None: nn.init.zeros_(m.bias) + def get_frame_level_feat(self, x): + # for outer interface + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + x = self.head(x) + for layer in self.xvector[:-2]: + x = layer(x) + + out = x.permute(0, 2, 1) + + return out # (B, T, D) + def forward(self, x): x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) x = self.head(x) diff --git a/wespeaker/models/ecapa_tdnn.py b/wespeaker/models/ecapa_tdnn.py index cb7db442..d02d4a01 100644 --- a/wespeaker/models/ecapa_tdnn.py +++ b/wespeaker/models/ecapa_tdnn.py @@ -1,6 +1,7 @@ # Copyright (c) 2021 Zhengyang Chen (chenzhengyang117@gmail.com) # 2022 Hongji Wang (jijijiang77@gmail.com) # 2023 Bing Han (hanbing97@sjtu.edu.cn) +# 2024 Zhengyang Chen (chenzhengyang117@gmail.com) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -204,7 +205,8 @@ def __init__(self, else: self.bn2 = nn.Identity() - def forward(self, x): + def __get_frame_level_feat(self, x): + # for inner class usage x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T) out1 = self.layer1(x) @@ -213,7 +215,17 @@ def forward(self, x): out4 = self.layer4(out3) out = torch.cat([out2, out3, out4], dim=1) - out = F.relu(self.conv(out)) + out = self.conv(out) + + return out + + def get_frame_level_feat(self, x): + # for outer interface + out = self.__get_frame_level_feat(x).permute(0, 2, 1) + return out # (B, T, D) + + def forward(self, x): + out = F.relu(self.__get_frame_level_feat(x)) out = self.bn(self.pool(out)) out = self.linear(out) if self.emb_bn: diff --git a/wespeaker/models/eres2net.py b/wespeaker/models/eres2net.py index 9e56219a..41a26fdb 100644 --- a/wespeaker/models/eres2net.py +++ b/wespeaker/models/eres2net.py @@ -1,4 +1,5 @@ # Copyright (c) 2024 Hongji Wang (jijijiang77@gmail.com) +# 2024 Zhengyang Chen (chenzhengyang117@gmail.com) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -350,7 +351,8 @@ def _make_layer(self, self.in_planes = planes * self.expansion return nn.Sequential(*layers) - def forward(self, x): + def __get_frame_level_feat(self, x): + # for inner class usage x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) x = x.unsqueeze_(1) out = F.relu(self.bn1(self.conv1(x))) @@ -364,6 +366,19 @@ def forward(self, x): out4 = self.layer4(out3) fuse_out123_downsample = self.layer3_downsample(fuse_out123) fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample) + + return fuse_out1234 + + def get_frame_level_feat(self, x): + # for outer interface + out = self.__get_frame_level_feat(x) + out = out.transpose(1, 3) + out = torch.flatten(out, 2, -1) + + return out # (B, T, D) + + def forward(self, x): + fuse_out1234 = self.__get_frame_level_feat(x) stats = self.pool(fuse_out1234) embed_a = self.seg_1(stats) diff --git a/wespeaker/models/gemini_dfresnet.py b/wespeaker/models/gemini_dfresnet.py index 61918369..f9ebeb6e 100644 --- a/wespeaker/models/gemini_dfresnet.py +++ b/wespeaker/models/gemini_dfresnet.py @@ -1,5 +1,6 @@ # Copyright (c) 2024 Shuai Wang (wsstriving@gmail.com) # 2024 Tianchi Liu (tianchi_liu@u.nus.edu) +# 2024 Zhengyang Chen (chenzhengyang117@gmail.com) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -101,7 +102,8 @@ def __init__(self, self.seg_bn_1 = nn.Identity() self.seg_2 = nn.Identity() - def forward(self, x): + def __get_frame_level_feat(self, x): + # for inner class usage x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) x = x.unsqueeze_(1) out = self.downsample_layers[0](x) @@ -114,6 +116,19 @@ def forward(self, x): out = self.downsample_layers[4](out) out = self.stages[3](out) + return out + + def get_frame_level_feat(self, x): + # for outer interface + out = self.__get_frame_level_feat(x) + out = out.transpose(1, 3) + out = torch.flatten(out, 2, -1) + + return out # (B, T, D) + + def forward(self, x): + + out = self.__get_frame_level_feat(x) stats = self.pool(out) embed_a = self.seg_1(stats) diff --git a/wespeaker/models/redimnet.py b/wespeaker/models/redimnet.py index ecd7db17..6ded935e 100644 --- a/wespeaker/models/redimnet.py +++ b/wespeaker/models/redimnet.py @@ -1,5 +1,6 @@ # Copyright (c) 2024 https://github.com/IDRnD/ReDimNet # 2024 Shuai Wang (wsstriving@gmail.com) +# 2024 Zhengyang Chen (chenzhengyang117@gmail.com) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -38,13 +39,13 @@ import torch.nn.functional as F import wespeaker.models.pooling_layers as pooling_layers - MaxPoolNd = {1: nn.MaxPool1d, 2: nn.MaxPool2d} ConvNd = {1: nn.Conv1d, 2: nn.Conv2d} BatchNormNd = {1: nn.BatchNorm1d, 2: nn.BatchNorm2d} class to1d(nn.Module): + def forward(self, x): size = x.size() bs, c, f, t = tuple(size) @@ -52,18 +53,11 @@ def forward(self, x): class NewGELUActivation(nn.Module): + def forward(self, input): - return ( - 0.5 - * input - * ( - 1.0 - + torch.tanh( - math.sqrt(2.0 / math.pi) - * (input + 0.044715 * torch.pow(input, 3.0)) - ) - ) - ) + return (0.5 * input * (1.0 + torch.tanh( + math.sqrt(2.0 / math.pi) * + (input + 0.044715 * torch.pow(input, 3.0))))) class LayerNorm(nn.Module): @@ -82,7 +76,7 @@ def __init__(self, C, eps=1e-6, data_format="channels_last"): self.data_format = data_format if self.data_format not in ["channels_last", "channels_first"]: raise NotImplementedError - self.C = (C,) + self.C = (C, ) def forward(self, x): if self.data_format == "channels_last": @@ -101,19 +95,17 @@ def forward(self, x): return x def extra_repr(self) -> str: - return ", ".join( - [ - f"{k}={v}" - for k, v in { - "C": self.C, - "data_format": self.data_format, - "eps": self.eps, - }.items() - ] - ) + return ", ".join([ + f"{k}={v}" for k, v in { + "C": self.C, + "data_format": self.data_format, + "eps": self.eps, + }.items() + ]) class GRU(nn.Module): + def __init__(self, *args, **kwargs): super(GRU, self).__init__() self.gru = nn.GRU(*args, **kwargs) @@ -124,12 +116,15 @@ def forward(self, x): class PosEncConv(nn.Module): + def __init__(self, C, ks, groups=None): super().__init__() assert ks % 2 == 1 - self.conv = nn.Conv1d( - C, C, ks, padding=ks // 2, groups=C if groups is None else groups - ) + self.conv = nn.Conv1d(C, + C, + ks, + padding=ks // 2, + groups=C if groups is None else groups) self.norm = LayerNorm(C, eps=1e-6, data_format="channels_first") def forward(self, x): @@ -137,27 +132,25 @@ def forward(self, x): class ConvNeXtLikeBlock(nn.Module): + def __init__( - self, - C, - dim=2, - kernel_sizes=((3, 3),), - group_divisor=1, - padding="same", + self, + C, + dim=2, + kernel_sizes=((3, 3), ), + group_divisor=1, + padding="same", ): super().__init__() - self.dwconvs = nn.ModuleList( - modules=[ - ConvNd[dim]( - C, - C, - kernel_size=ks, - padding=padding, - groups=C // group_divisor if group_divisor is not None else 1, - ) - for ks in kernel_sizes - ] - ) + self.dwconvs = nn.ModuleList(modules=[ + ConvNd[dim]( + C, + C, + kernel_size=ks, + padding=padding, + groups=C // group_divisor if group_divisor is not None else 1, + ) for ks in kernel_sizes + ]) self.norm = BatchNormNd[dim](C * len(kernel_sizes)) self.gelu = nn.GELU() self.pwconv1 = ConvNd[dim](C * len(kernel_sizes), C, 1) @@ -172,6 +165,7 @@ def forward(self, x): class ConvBlock2d(nn.Module): + def __init__(self, c, f, block_type="convnext_like", group_divisor=1): super().__init__() if block_type == "convnext_like": @@ -238,11 +232,8 @@ def __init__( self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) def _shape(self, tensor: torch.Tensor, seq_len, bsz): - return ( - tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - .transpose(1, 2) - .contiguous() - ) + return (tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous()) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """Input shape: Batch x Time x Channel""" @@ -255,18 +246,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: value_states = self._shape(self.v_proj(hidden_states), -1, bsz) proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = self._shape(query_states, tgt_len, + bsz).view(*proj_shape) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) attn_weights = F.softmax(attn_weights, dim=-1) - attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) + attn_probs = F.dropout(attn_weights, + p=self.dropout, + training=self.training) attn_output = torch.bmm(attn_probs, value_states) - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, + self.head_dim) attn_output = attn_output.transpose(1, 2) # Use the `embed_dim` from the config (stored in the class) @@ -279,6 +274,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class TransformerEncoderLayer(nn.Module): + def __init__( self, n_state, @@ -333,6 +329,7 @@ def forward(self, hidden_states): class FeedForward(nn.Module): + def __init__( self, hidden_size, @@ -378,7 +375,8 @@ def __init__( stride=stride, padding=1, bias=False, - groups=in_planes // group_divisor if group_divisor is not None else 1, + groups=in_planes // + group_divisor if group_divisor is not None else 1, ) # If using group convolution, add point-wise conv to reshape @@ -408,7 +406,11 @@ def __init__( if planes != in_planes: self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), + nn.Conv2d(in_planes, + planes, + kernel_size=1, + stride=stride, + bias=False), nn.BatchNorm2d(planes), ) else: @@ -457,6 +459,7 @@ def forward(self, inputs): class ResBasicBlock(nn.Module): + def __init__( self, in_planes, @@ -475,7 +478,8 @@ def __init__( stride=stride, padding=1, bias=False, - groups=in_planes // group_divisor if group_divisor is not None else 1, + groups=in_planes // + group_divisor if group_divisor is not None else 1, ) if group_divisor is not None: self.conv1pw = nn.Conv2d(in_planes, planes, 1) @@ -507,7 +511,11 @@ def __init__( if planes != in_planes: self.downsample = nn.Sequential( - nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), + nn.Conv2d(in_planes, + planes, + kernel_size=1, + stride=stride, + bias=False), nn.BatchNorm2d(planes), ) else: @@ -543,8 +551,8 @@ def __init__( assert pos_ker_sz self.red_dim_conv = nn.Sequential( - nn.Conv1d(C, hC, 1), LayerNorm(hC, eps=1e-6, data_format="channels_first") - ) + nn.Conv1d(C, hC, 1), + LayerNorm(hC, eps=1e-6, data_format="channels_first")) if block_type == "fc": self.tcm = nn.Sequential( @@ -576,18 +584,26 @@ def __init__( elif block_type == "conv+att": # Basic Transformer self-attention encoder block self.tcm = nn.Sequential( - ConvNeXtLikeBlock( - hC, dim=1, kernel_sizes=[7], group_divisor=1, padding="same" - ), - ConvNeXtLikeBlock( - hC, dim=1, kernel_sizes=[19], group_divisor=1, padding="same" - ), - ConvNeXtLikeBlock( - hC, dim=1, kernel_sizes=[31], group_divisor=1, padding="same" - ), - ConvNeXtLikeBlock( - hC, dim=1, kernel_sizes=[59], group_divisor=1, padding="same" - ), + ConvNeXtLikeBlock(hC, + dim=1, + kernel_sizes=[7], + group_divisor=1, + padding="same"), + ConvNeXtLikeBlock(hC, + dim=1, + kernel_sizes=[19], + group_divisor=1, + padding="same"), + ConvNeXtLikeBlock(hC, + dim=1, + kernel_sizes=[31], + group_divisor=1, + padding="same"), + ConvNeXtLikeBlock(hC, + dim=1, + kernel_sizes=[59], + group_divisor=1, + padding="same"), TransformerEncoderLayer(n_state=hC, n_mlp=hC, n_head=4), ) else: @@ -604,6 +620,7 @@ def forward(self, x): class ReDimNetBone(nn.Module): + def __init__( self, F=72, @@ -641,31 +658,30 @@ def build(self, stages_setup, group_divisor, out_channels): # Weighting the inputs # TODO: ask authors about the impact of this pre-weighting self.inputs_weights = torch.nn.ParameterList( - [nn.Parameter(torch.ones(1, 1, 1, 1), requires_grad=False)] - + [ + [nn.Parameter(torch.ones(1, 1, 1, 1), requires_grad=False)] + [ nn.Parameter( torch.zeros(1, num_inputs + 1, self.C * self.F, 1), requires_grad=True, - ) - for num_inputs in range(1, len(stages_setup) + 1) - ] - ) + ) for num_inputs in range(1, + len(stages_setup) + 1) + ]) self.stem = nn.Sequential( nn.Conv2d(1, int(cur_c), kernel_size=3, stride=1, padding="same"), LayerNorm(int(cur_c), eps=1e-6, data_format="channels_first"), ) - Block1d = functools.partial(TimeContextBlock1d, block_type=self.block_1d_type) + Block1d = functools.partial(TimeContextBlock1d, + block_type=self.block_1d_type) Block2d = functools.partial(ConvBlock2d, block_type=self.block_2d_type) self.stages_cfs = [] for stage_ind, ( - stride, - num_blocks, - conv_exp, - kernel_sizes, # TODO: Why the kernel_sizes are not used? - att_block_red, + stride, + num_blocks, + conv_exp, + kernel_sizes, # TODO: Why the kernel_sizes are not used? + att_block_red, ) in enumerate(stages_setup): assert stride in [1, 2, 3] # Pool frequencies & expand channels if needed @@ -689,10 +705,9 @@ def build(self, stages_setup, group_divisor, out_channels): for _ in range(num_blocks): # ConvBlock2d(f, c, block_type="convnext_like", group_divisor=1) layers.append( - Block2d( - c=int(cur_c * conv_exp), f=cur_f, group_divisor=group_divisor - ) - ) + Block2d(c=int(cur_c * conv_exp), + f=cur_f, + group_divisor=group_divisor)) if conv_exp != 1: # Squeeze back channels to align with ReDimNet c+f reshaping: @@ -707,11 +722,8 @@ def build(self, stages_setup, group_divisor, out_channels): kernel_size=(3, 3), stride=1, padding="same", - groups=( - cur_c // _group_divisor - if _group_divisor is not None - else 1 - ), + groups=(cur_c // _group_divisor + if _group_divisor is not None else 1), ), nn.BatchNorm2d( cur_c, @@ -719,22 +731,24 @@ def build(self, stages_setup, group_divisor, out_channels): ), nn.GELU(), nn.Conv2d(cur_c, cur_c, 1), - ) - ) + )) layers.append(to1d()) # reduce block? if att_block_red is not None: layers.append( - Block1d(self.C * self.F, hC=(self.C * self.F) // att_block_red) - ) + Block1d(self.C * self.F, + hC=(self.C * self.F) // att_block_red)) setattr(self, f"stage{stage_ind}", nn.Sequential(*layers)) if out_channels is not None: self.mfa = nn.Sequential( - nn.Conv1d(self.F * self.C, out_channels, kernel_size=1, padding="same"), + nn.Conv1d(self.F * self.C, + out_channels, + kernel_size=1, + padding="same"), nn.BatchNorm1d(out_channels, affine=True), ) else: @@ -776,6 +790,7 @@ def forward(self, inp): class ReDimNet(nn.Module): + def __init__( self, feat_dim=72, @@ -817,8 +832,7 @@ def __init__( out_channels = C * feat_dim self.pool = getattr(pooling_layers, pooling_func)( - in_dim=out_channels, global_context_att=global_context_att - ) + in_dim=out_channels, global_context_att=global_context_att) self.pool_out_dim = self.pool.get_out_dim() self.seg_1 = nn.Linear(self.pool_out_dim, embed_dim) @@ -829,12 +843,23 @@ def __init__( self.seg_bn_1 = nn.Identity() self.seg_2 = nn.Identity() - def forward(self, x): - # x = self.spec(x).unsqueeze(1) + def __get_frame_level_feat(self, x): + # for inner class usage x = x.permute(0, 2, 1) # (B,F,T) => (B,T,F) x = x.unsqueeze_(1) out = self.backbone(x) + return out + + def get_frame_level_feat(self, x): + # for outer interface + out = self.__get_frame_level_feat(x).permute(0, 2, 1) + + return out # (B, T, D) + + def forward(self, x): + out = self.__get_frame_level_feat(x) + stats = self.pool(out) embed_a = self.seg_1(stats) if self.two_emb_layer: @@ -846,7 +871,10 @@ def forward(self, x): return torch.tensor(0.0), embed_a -def ReDimNetB0(feat_dim=60, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): +def ReDimNetB0(feat_dim=60, + embed_dim=192, + pooling_func="ASTP", + two_emb_layer=False): return ReDimNet( feat_dim=feat_dim, C=10, @@ -868,7 +896,10 @@ def ReDimNetB0(feat_dim=60, embed_dim=192, pooling_func="ASTP", two_emb_layer=Fa ) -def ReDimNetB1(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): +def ReDimNetB1(feat_dim=72, + embed_dim=192, + pooling_func="ASTP", + two_emb_layer=False): return ReDimNet( feat_dim=feat_dim, C=12, @@ -890,7 +921,10 @@ def ReDimNetB1(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=Fa ) -def ReDimNetB2(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): +def ReDimNetB2(feat_dim=72, + embed_dim=192, + pooling_func="ASTP", + two_emb_layer=False): return ReDimNet( feat_dim=feat_dim, C=16, @@ -913,7 +947,10 @@ def ReDimNetB2(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=Fa ) -def ReDimNetB3(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): +def ReDimNetB3(feat_dim=72, + embed_dim=192, + pooling_func="ASTP", + two_emb_layer=False): return ReDimNet( feat_dim=feat_dim, C=16, @@ -936,7 +973,10 @@ def ReDimNetB3(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=Fa ) -def ReDimNetB4(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): +def ReDimNetB4(feat_dim=72, + embed_dim=192, + pooling_func="ASTP", + two_emb_layer=False): return ReDimNet( feat_dim=feat_dim, C=32, @@ -959,7 +999,10 @@ def ReDimNetB4(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=Fa ) -def ReDimNetB5(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): +def ReDimNetB5(feat_dim=72, + embed_dim=192, + pooling_func="ASTP", + two_emb_layer=False): return ReDimNet( feat_dim=feat_dim, C=32, @@ -982,7 +1025,10 @@ def ReDimNetB5(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=Fa ) -def ReDimNetB6(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): +def ReDimNetB6(feat_dim=72, + embed_dim=192, + pooling_func="ASTP", + two_emb_layer=False): return ReDimNet( feat_dim=feat_dim, C=32, diff --git a/wespeaker/models/repvgg.py b/wespeaker/models/repvgg.py index cfbeb236..1827303a 100644 --- a/wespeaker/models/repvgg.py +++ b/wespeaker/models/repvgg.py @@ -1,5 +1,6 @@ # Copyright (c) 2021 xmuspeech (Author: Leo) # 2022 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn) +# 2024 Zhengyang Chen (chenzhengyang117@gmail.com) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -558,7 +559,8 @@ def get_downsample_multiple(self): def get_output_planes(self): return self.output_planes - def forward(self, x): + def __get_frame_level_feat(self, x): + # for inner class usage x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T) x = x.unsqueeze_(1) x = self.stage0(x) @@ -567,6 +569,18 @@ def forward(self, x): x = self.stage3(x) x = self.stage4(x) + return x + + def get_frame_level_feat(self, x): + # for outer interface + out = self.__get_frame_level_feat(x) + out = out.transpose(1, 3) + out = torch.flatten(out, 2, -1) + + return out # (B, T, D) + + def forward(self, x): + x = self.__get_frame_level_feat(x) stats = self.pool(x) embed = self.seg(stats) diff --git a/wespeaker/models/res2net.py b/wespeaker/models/res2net.py index 294f191f..d320d8b7 100644 --- a/wespeaker/models/res2net.py +++ b/wespeaker/models/res2net.py @@ -1,4 +1,5 @@ # Copyright (c) 2024 Hongji Wang (jijijiang77@gmail.com) +# 2024 Zhengyang Chen (chenzhengyang117@gmail.com) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -153,7 +154,8 @@ def _make_layer(self, block, planes, num_blocks, stride): self.in_planes = planes * block.expansion return nn.Sequential(*layers) - def forward(self, x): + def __get_frame_level_feat(self, x): + # for inner class usage x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) x = x.unsqueeze_(1) @@ -163,6 +165,18 @@ def forward(self, x): out = self.layer3(out) out = self.layer4(out) + return out + + def get_frame_level_feat(self, x): + # for outer interface + out = self.__get_frame_level_feat(x) + out = out.transpose(1, 3) + out = torch.flatten(out, 2, -1) + + return out # (B, T, D) + + def forward(self, x): + out = self.__get_frame_level_feat(x) stats = self.pool(out) embed_a = self.seg_1(stats) diff --git a/wespeaker/models/resnet.py b/wespeaker/models/resnet.py index 0e223c21..3fb51186 100644 --- a/wespeaker/models/resnet.py +++ b/wespeaker/models/resnet.py @@ -1,6 +1,7 @@ # Copyright (c) 2021 Shuai Wang (wsstriving@gmail.com) # 2022 Zhengyang Chen (chenzhengyang117@gmail.com) # 2023 Bing Han (hanbing97@sjtu.edu.cn) +# 2024 Zhengyang Chen (chenzhengyang117@gmail.com) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -167,7 +168,8 @@ def _make_layer(self, block, planes, num_blocks, stride): self.in_planes = planes * block.expansion return nn.Sequential(*layers) - def forward(self, x): + def __get_frame_level_feat(self, x): + # for inner class usage x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) x = x.unsqueeze_(1) @@ -177,6 +179,19 @@ def forward(self, x): out = self.layer3(out) out = self.layer4(out) + return out + + def get_frame_level_feat(self, x): + # for outer interface + out = self.__get_frame_level_feat(x) + out = out.transpose(1, 3) + out = torch.flatten(out, 2, -1) + + return out # (B, T, D) + + def forward(self, x): + out = self.__get_frame_level_feat(x) + stats = self.pool(out) embed_a = self.seg_1(stats) diff --git a/wespeaker/models/tdnn.py b/wespeaker/models/tdnn.py index 19c147d7..398bbc3b 100644 --- a/wespeaker/models/tdnn.py +++ b/wespeaker/models/tdnn.py @@ -1,4 +1,5 @@ # Copyright (c) 2021 Shuai Wang (wsstriving@gmail.com) +# 2024 Zhengyang Chen (chenzhengyang117@gmail.com) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -85,7 +86,8 @@ def __init__(self, self.seg_bn_1 = nn.BatchNorm1d(embed_dim, affine=False) self.seg_2 = nn.Linear(embed_dim, embed_dim) - def forward(self, x): + def __get_frame_level_feat(self, x): + # for inner class usage x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T) out = self.frame_1(x) @@ -94,6 +96,16 @@ def forward(self, x): out = self.frame_4(out) out = self.frame_5(out) + return out + + def get_frame_level_feat(self, x): + # for outer interface + out = self.__get_frame_level_feat(x).permute(0, 2, 1) + + return out # (B, T, D) + + def forward(self, x): + out = self.__get_frame_level_feat(x) stats = self.pool(out) embed_a = self.seg_1(stats) out = F.relu(embed_a)