diff --git a/src/torch_wae/network.py b/src/torch_wae/network.py index 9e230da..16b93bf 100644 --- a/src/torch_wae/network.py +++ b/src/torch_wae/network.py @@ -15,10 +15,14 @@ def __init__(self, s: int) -> None: self.preprocess = Preprocess() self.encoder = Encoder(s=s) + self.head = WAEConvHead(s=s) + self.norm = L2Normalize() def forward(self, waveform: torch.Tensor) -> torch.Tensor: - x = self.preprocess(waveform) - z = self.encoder(x) + h = self.preprocess(waveform) + h = self.encoder(h) + h = self.head(h) + z = self.norm(h) return z @@ -60,6 +64,17 @@ def __init__(self, s: int) -> None: # -------------------- InvertedBottleneck(k=3, c_in=32 * s, c_out=64 * s, stride=1), InvertedBottleneck(k=3, c_in=64 * s, c_out=64 * s, stride=1), + ) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor]: + return self.layers(x) + + +class WAEConvHead(nn.Module): + def __init__(self, s: int) -> None: + super().__init__() + + self.layers = nn.Sequential( # -------------------- # shape: (64, 4, 4) -> (64, 1, 1) # -------------------- @@ -67,15 +82,11 @@ def __init__(self, s: int) -> None: nn.BatchNorm2d(64 * s), nn.LeakyReLU(), nn.Conv2d(64 * s, 64 * s, 1, stride=1), - # -------------------- - # normalize - # -------------------- nn.Flatten(), - L2Normalize(), ) - def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - return self.layers(x) + def forward(self, h: torch.Tensor) -> tuple[torch.Tensor]: + return self.layers(h) class LightUNet(nn.Module): diff --git a/tests/test__network.py b/tests/test__network.py index 797876e..cd8ebe0 100644 --- a/tests/test__network.py +++ b/tests/test__network.py @@ -2,7 +2,7 @@ import torch -from torch_wae.network import LightUNet, Preprocess, WAENet +from torch_wae.network import LightUNet, Preprocess, WAEConvHead, WAENet def test__preprocess_shape() -> None: @@ -28,6 +28,15 @@ def test__wae_encoder_shape() -> None: f = WAENet(s=s).train(False) x = torch.randn((1, 1, d, d)) z = f.encoder(x) + assert z.shape == (1, d * s, 4, 4) + + +def test__wae_conv_head_shape() -> None: + for s in (1, 2): + d = 64 + f = WAEConvHead(s=s).train(False) + x = torch.randn((1, d * s, 4, 4)) + z = f(x) assert z.shape == (1, d * s)