Skip to content

Commit

Permalink
Normalize log-melspectrogram.
Browse files Browse the repository at this point in the history
  • Loading branch information
mitsuse committed Dec 14, 2024
1 parent 59e424d commit c8e7a1e
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/torch_wae/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
class Preprocess(nn.Module):
SAMPLE_RATE: int = 16000

def __init__(self) -> None:
def __init__(self, eps: float = 1e-6) -> None:
super().__init__()

self.melSpec = Spectrogram(
Expand All @@ -335,12 +335,17 @@ def __init__(self) -> None:
power=2,
)
self.melSpec.set_mode("DFT", "on_the_fly")
self.eps = eps

with torch.no_grad():
x = torch.zeros((1, self.SAMPLE_RATE))
x = self.melSpec(x)
self.log_mel_min = torch.log(x + self.eps).min()

def forward(self, waveform: torch.Tensor) -> torch.Tensor:
x = self.melSpec(waveform)
x = torch.clip(x, 1e-6)
x = torch.log(x)
x = x[:, None, :, :]
x = torch.clip(torch.log(x + self.eps) - self.log_mel_min, 0.0)
x = x.unsqueeze(1)
return x


Expand Down

0 comments on commit c8e7a1e

Please sign in to comment.