Commit a23aa8f: wip

mitsuse committed Dec 4, 2024
1 parent 6d9aaf3 · commit a23aa8f

Showing 2 changed files with 29 additions and 9 deletions.
src/torch_wae/network.py: 27 changes (19 additions & 8 deletions)
@@ -15,10 +15,14 @@ def __init__(self, s: int) -> None:
         self.preprocess = Preprocess()
 
         self.encoder = Encoder(s=s)
+        self.head = WAEConvHead(s=s)
+        self.norm = L2Normalize()
 
     def forward(self, waveform: torch.Tensor) -> torch.Tensor:
-        x = self.preprocess(waveform)
-        z = self.encoder(x)
+        h = self.preprocess(waveform)
+        h = self.encoder(h)
+        h = self.head(h)
+        z = self.norm(h)
         return z


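Note: the refactor splits the old monolithic encoder into trunk, head, and normalization stages that can be exercised separately. Below is a minimal usage sketch of the new pipeline; the shapes are taken from the tests further down, and feeding random (1, 1, 64, 64) features straight into the encoder (skipping Preprocess) is an assumption borrowed from those same tests.

    import torch

    from torch_wae.network import WAENet

    f = WAENet(s=1).train(False)

    # Assumption: a preprocessed (1, 1, 64, 64) feature map, as in the tests.
    x = torch.randn((1, 1, 64, 64))
    h = f.encoder(x)   # trunk: (1, 1, 64, 64) -> (1, 64, 4, 4)
    h = f.head(h)      # head:  (1, 64, 4, 4)  -> (1, 64)
    z = f.norm(h)      # norm:  unit L2 length, shape stays (1, 64)

    assert z.shape == (1, 64)
    assert torch.allclose(z.norm(dim=-1), torch.tensor(1.0), atol=1e-5)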
@@ -60,22 +64,29 @@ def __init__(self, s: int) -> None:
             # --------------------
             InvertedBottleneck(k=3, c_in=32 * s, c_out=64 * s, stride=1),
             InvertedBottleneck(k=3, c_in=64 * s, c_out=64 * s, stride=1),
+        )
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor]:
+        return self.layers(x)
+
+
+class WAEConvHead(nn.Module):
+    def __init__(self, s: int) -> None:
+        super().__init__()
+
+        self.layers = nn.Sequential(
             # --------------------
             # shape: (64, 4, 4) -> (64, 1, 1)
             # --------------------
             nn.Conv2d(64 * s, 64 * s, 4, stride=1),
             nn.BatchNorm2d(64 * s),
             nn.LeakyReLU(),
             nn.Conv2d(64 * s, 64 * s, 1, stride=1),
             # --------------------
             # normalize
             # --------------------
             nn.Flatten(),
             L2Normalize(),
         )
 
-    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        return self.layers(x)
+    def forward(self, h: torch.Tensor) -> tuple[torch.Tensor]:
+        return self.layers(h)
 
 
 class LightUNet(nn.Module):
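Two details are worth noting here. First, the head's 4x4 convolution with stride 1 and no padding is what collapses the (64 * s, 4, 4) grid to (64 * s, 1, 1): for a valid convolution, out = in - k + 1 = 4 - 4 + 1 = 1. Second, L2Normalize is referenced but defined outside this hunk. A plausible minimal sketch, assuming the module simply rescales vectors to unit Euclidean length (the repository's actual definition may differ):

    import torch
    from torch import nn


    class L2Normalize(nn.Module):
        """Rescale each vector to unit Euclidean norm along the last dimension."""

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # eps avoids division by zero for all-zero vectors.
            return nn.functional.normalize(x, p=2, dim=-1, eps=1e-12)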
tests/test__network.py: 11 changes (10 additions & 1 deletion)

@@ -2,7 +2,7 @@
 
 import torch
 
-from torch_wae.network import LightUNet, Preprocess, WAENet
+from torch_wae.network import LightUNet, Preprocess, WAEConvHead, WAENet


def test__preprocess_shape() -> None:
@@ -28,6 +28,15 @@ def test__wae_encoder_shape() -> None:
         f = WAENet(s=s).train(False)
         x = torch.randn((1, 1, d, d))
         z = f.encoder(x)
         assert z.shape == (1, d * s, 4, 4)
+
+
+def test__wae_conv_head_shape() -> None:
+    for s in (1, 2):
+        d = 64
+        f = WAEConvHead(s=s).train(False)
+        x = torch.randn((1, d * s, 4, 4))
+        z = f(x)
+        assert z.shape == (1, d * s)
 
 
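Since WAEConvHead ends with L2Normalize, embeddings come out unit-length, so cosine similarity between two clips reduces to a plain dot product. A hedged usage sketch, with random tensors standing in for real preprocessed audio:

    import torch

    from torch_wae.network import WAENet

    f = WAENet(s=1).train(False)

    with torch.no_grad():
        za = f.norm(f.head(f.encoder(torch.randn((1, 1, 64, 64)))))
        zb = f.norm(f.head(f.encoder(torch.randn((1, 1, 64, 64)))))

    # Unit-norm embeddings: the dot product is the cosine similarity.
    cosine = (za * zb).sum(dim=-1)
    print(float(cosine))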
