From 0793f3bafa42c96a18e0e70ccd4be63030a20044 Mon Sep 17 00:00:00 2001 From: zhengyang Date: Wed, 25 Sep 2024 19:24:48 +0800 Subject: [PATCH] [models] update frame-level feature extraction interface --- wespeaker/models/ecapa_tdnn.py | 6 +++--- wespeaker/models/eres2net.py | 6 +++--- wespeaker/models/gemini_dfresnet.py | 6 +++--- wespeaker/models/redimnet.py | 6 +++--- wespeaker/models/repvgg.py | 6 +++--- wespeaker/models/res2net.py | 6 +++--- wespeaker/models/resnet.py | 6 +++--- wespeaker/models/tdnn.py | 6 +++--- 8 files changed, 24 insertions(+), 24 deletions(-) diff --git a/wespeaker/models/ecapa_tdnn.py b/wespeaker/models/ecapa_tdnn.py index d02d4a01..99824f9d 100644 --- a/wespeaker/models/ecapa_tdnn.py +++ b/wespeaker/models/ecapa_tdnn.py @@ -205,7 +205,7 @@ def __init__(self, else: self.bn2 = nn.Identity() - def __get_frame_level_feat(self, x): + def _get_frame_level_feat(self, x): # for inner class usage x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T) @@ -221,11 +221,11 @@ def __get_frame_level_feat(self, x): def get_frame_level_feat(self, x): # for outer interface - out = self.__get_frame_level_feat(x).permute(0, 2, 1) + out = self._get_frame_level_feat(x).permute(0, 2, 1) return out # (B, T, D) def forward(self, x): - out = F.relu(self.__get_frame_level_feat(x)) + out = F.relu(self._get_frame_level_feat(x)) out = self.bn(self.pool(out)) out = self.linear(out) if self.emb_bn: diff --git a/wespeaker/models/eres2net.py b/wespeaker/models/eres2net.py index 41a26fdb..ea02a0c6 100644 --- a/wespeaker/models/eres2net.py +++ b/wespeaker/models/eres2net.py @@ -351,7 +351,7 @@ def _make_layer(self, self.in_planes = planes * self.expansion return nn.Sequential(*layers) - def __get_frame_level_feat(self, x): + def _get_frame_level_feat(self, x): # for inner class usage x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) x = x.unsqueeze_(1) @@ -371,14 +371,14 @@ def __get_frame_level_feat(self, x): def get_frame_level_feat(self, x): # for outer interface - out = self.__get_frame_level_feat(x) + out = self._get_frame_level_feat(x) out = out.transpose(1, 3) out = torch.flatten(out, 2, -1) return out # (B, T, D) def forward(self, x): - fuse_out1234 = self.__get_frame_level_feat(x) + fuse_out1234 = self._get_frame_level_feat(x) stats = self.pool(fuse_out1234) embed_a = self.seg_1(stats) diff --git a/wespeaker/models/gemini_dfresnet.py b/wespeaker/models/gemini_dfresnet.py index f9ebeb6e..52f36630 100644 --- a/wespeaker/models/gemini_dfresnet.py +++ b/wespeaker/models/gemini_dfresnet.py @@ -102,7 +102,7 @@ def __init__(self, self.seg_bn_1 = nn.Identity() self.seg_2 = nn.Identity() - def __get_frame_level_feat(self, x): + def _get_frame_level_feat(self, x): # for inner class usage x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) x = x.unsqueeze_(1) @@ -120,7 +120,7 @@ def __get_frame_level_feat(self, x): def get_frame_level_feat(self, x): # for outer interface - out = self.__get_frame_level_feat(x) + out = self._get_frame_level_feat(x) out = out.transpose(1, 3) out = torch.flatten(out, 2, -1) @@ -128,7 +128,7 @@ def get_frame_level_feat(self, x): def forward(self, x): - out = self.__get_frame_level_feat(x) + out = self._get_frame_level_feat(x) stats = self.pool(out) embed_a = self.seg_1(stats) diff --git a/wespeaker/models/redimnet.py b/wespeaker/models/redimnet.py index 6ded935e..bc5434aa 100644 --- a/wespeaker/models/redimnet.py +++ b/wespeaker/models/redimnet.py @@ -843,7 +843,7 @@ def __init__( self.seg_bn_1 = nn.Identity() self.seg_2 = nn.Identity() - def __get_frame_level_feat(self, x): + def _get_frame_level_feat(self, x): # for inner class usage x = x.permute(0, 2, 1) # (B,F,T) => (B,T,F) x = x.unsqueeze_(1) @@ -853,12 +853,12 @@ def __get_frame_level_feat(self, x): def get_frame_level_feat(self, x): # for outer interface - out = self.__get_frame_level_feat(x).permute(0, 2, 1) + out = self._get_frame_level_feat(x).permute(0, 2, 1) return out # (B, T, D) def forward(self, x): - out = self.__get_frame_level_feat(x) + out = self._get_frame_level_feat(x) stats = self.pool(out) embed_a = self.seg_1(stats) diff --git a/wespeaker/models/repvgg.py b/wespeaker/models/repvgg.py index 1827303a..b5e3298f 100644 --- a/wespeaker/models/repvgg.py +++ b/wespeaker/models/repvgg.py @@ -559,7 +559,7 @@ def get_downsample_multiple(self): def get_output_planes(self): return self.output_planes - def __get_frame_level_feat(self, x): + def _get_frame_level_feat(self, x): # for inner class usage x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T) x = x.unsqueeze_(1) @@ -573,14 +573,14 @@ def __get_frame_level_feat(self, x): def get_frame_level_feat(self, x): # for outer interface - out = self.__get_frame_level_feat(x) + out = self._get_frame_level_feat(x) out = out.transpose(1, 3) out = torch.flatten(out, 2, -1) return out # (B, T, D) def forward(self, x): - x = self.__get_frame_level_feat(x) + x = self._get_frame_level_feat(x) stats = self.pool(x) embed = self.seg(stats) diff --git a/wespeaker/models/res2net.py b/wespeaker/models/res2net.py index d320d8b7..c7833963 100644 --- a/wespeaker/models/res2net.py +++ b/wespeaker/models/res2net.py @@ -154,7 +154,7 @@ def _make_layer(self, block, planes, num_blocks, stride): self.in_planes = planes * block.expansion return nn.Sequential(*layers) - def __get_frame_level_feat(self, x): + def _get_frame_level_feat(self, x): # for inner class usage x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) @@ -169,14 +169,14 @@ def __get_frame_level_feat(self, x): def get_frame_level_feat(self, x): # for outer interface - out = self.__get_frame_level_feat(x) + out = self._get_frame_level_feat(x) out = out.transpose(1, 3) out = torch.flatten(out, 2, -1) return out # (B, T, D) def forward(self, x): - out = self.__get_frame_level_feat(x) + out = self._get_frame_level_feat(x) stats = self.pool(out) embed_a = self.seg_1(stats) diff --git a/wespeaker/models/resnet.py b/wespeaker/models/resnet.py index 3fb51186..13df0a23 100644 --- a/wespeaker/models/resnet.py +++ b/wespeaker/models/resnet.py @@ -168,7 +168,7 @@ def _make_layer(self, block, planes, num_blocks, stride): self.in_planes = planes * block.expansion return nn.Sequential(*layers) - def __get_frame_level_feat(self, x): + def _get_frame_level_feat(self, x): # for inner class usage x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) @@ -183,14 +183,14 @@ def __get_frame_level_feat(self, x): def get_frame_level_feat(self, x): # for outer interface - out = self.__get_frame_level_feat(x) + out = self._get_frame_level_feat(x) out = out.transpose(1, 3) out = torch.flatten(out, 2, -1) return out # (B, T, D) def forward(self, x): - out = self.__get_frame_level_feat(x) + out = self._get_frame_level_feat(x) stats = self.pool(out) diff --git a/wespeaker/models/tdnn.py b/wespeaker/models/tdnn.py index 398bbc3b..59724437 100644 --- a/wespeaker/models/tdnn.py +++ b/wespeaker/models/tdnn.py @@ -86,7 +86,7 @@ def __init__(self, self.seg_bn_1 = nn.BatchNorm1d(embed_dim, affine=False) self.seg_2 = nn.Linear(embed_dim, embed_dim) - def __get_frame_level_feat(self, x): + def _get_frame_level_feat(self, x): # for inner class usage x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T) @@ -100,12 +100,12 @@ def __get_frame_level_feat(self, x): def get_frame_level_feat(self, x): # for outer interface - out = self.__get_frame_level_feat(x).permute(0, 2, 1) + out = self._get_frame_level_feat(x).permute(0, 2, 1) return out # (B, T, D) def forward(self, x): - out = self.__get_frame_level_feat(x) + out = self._get_frame_level_feat(x) stats = self.pool(out) embed_a = self.seg_1(stats) out = F.relu(embed_a)