diff --git a/src/torch_wae/cli/export_to_onnx.py b/src/torch_wae/cli/export_to_onnx.py index c1bbdc3..67964ac 100644 --- a/src/torch_wae/cli/export_to_onnx.py +++ b/src/torch_wae/cli/export_to_onnx.py @@ -5,7 +5,7 @@ import torch import typer -from torch_wae.network import WAENet, WithResample +from torch_wae.network import WAEActivationType, WAEHeadType, WAENet, WithResample app = typer.Typer() @@ -20,12 +20,24 @@ def main( 48000, help="the original sample-rate for input (NOTE: WAENet resamples audio to 16KHz)", ), + head_type: WAEHeadType = typer.Option( + ..., + help="", + ), + head_activation_type: WAEActivationType = typer.Option( + ..., + help="", + ), output: Path = typer.Option( ..., help="the output path of a model converted for ONNX", ), ) -> None: - f = WAENet(s=1) + f = WAENet( + head_type=head_type, + head_activation_type=head_activation_type, + s=1, + ) assert sample_rate >= f.preprocess.SAMPLE_RATE f.load_state_dict(torch.load(pt)) @@ -38,7 +50,7 @@ def main( output.parent.mkdir(parents=True, exist_ok=True) torch.onnx.export( m, - waveform, + (waveform,), str(output), export_params=True, input_names=["waveform"], diff --git a/src/torch_wae/network.py b/src/torch_wae/network.py index 9e230da..fb448bb 100644 --- a/src/torch_wae/network.py +++ b/src/torch_wae/network.py @@ -1,5 +1,8 @@ from __future__ import annotations +import logging +from enum import Enum + import torch from convmelspec.stft import ConvertibleSpectrogram as Spectrogram from torch import nn @@ -7,18 +10,74 @@ from torchaudio import functional as FA +class WAEHeadType(str, Enum): + CONV = "conv" + LINEAR = "linear" + ATTEN_1 = "atten_1" + ATTEN_2 = "atten_2" + ATTEN_4 = "atten_4" + + +class WAEActivationType(str, Enum): + LEAKY_RELU = "leaky_relu" + TANH = "tanh" + + # Wowrd Audio Encoder - A network for audio similar to MobileNet V2 for images. class WAENet(nn.Module): - def __init__(self, s: int) -> None: + def __init__( + self, + s: int, + head_type: WAEHeadType, + head_activation_type: WAEActivationType, + ) -> None: super().__init__() self.preprocess = Preprocess() self.encoder = Encoder(s=s) + if head_type not in (WAEHeadType.CONV, WAEHeadType.LINEAR): + logging.debug( + "`head_activation_type` is not supported with specified `head_type`" + ) + + match head_type: + case WAEHeadType.CONV: + self.head: nn.Module = WAEConvHead( + activation_type=head_activation_type, + s=s, + ) + case WAEHeadType.LINEAR: + self.head = WAELinearHead( + activation_type=head_activation_type, + s=s, + ) + case WAEHeadType.ATTEN_1: + self.head = WAEAttentionHead( + n_head=1, + s=s, + ) + case WAEHeadType.ATTEN_2: + self.head = WAEAttentionHead( + n_head=2, + s=s, + ) + case WAEHeadType.ATTEN_4: + self.head = WAEAttentionHead( + n_head=4, + s=s, + ) + case _: + raise ValueError(f"unknown head type: {head_type}") + + self.norm = L2Normalize() + def forward(self, waveform: torch.Tensor) -> torch.Tensor: - x = self.preprocess(waveform) - z = self.encoder(x) + h = self.preprocess(waveform) + h = self.encoder(h) + h = self.head(h) + z = self.norm(h) return z @@ -60,22 +119,97 @@ def __init__(self, s: int) -> None: # -------------------- InvertedBottleneck(k=3, c_in=32 * s, c_out=64 * s, stride=1), InvertedBottleneck(k=3, c_in=64 * s, c_out=64 * s, stride=1), + ) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor]: + return self.layers(x) + + +class WAEConvHead(nn.Module): + def __init__( + self, + activation_type: WAEActivationType, + s: int, + ) -> None: + super().__init__() + + match activation_type: + case WAEActivationType.LEAKY_RELU: + activation: nn.Module = nn.LeakyReLU() + case WAEActivationType.TANH: + activation = nn.Tanh() + case _: + raise ValueError(f"unknown activation type: {activation_type}") + + self.layers = nn.Sequential( # -------------------- # shape: (64, 4, 4) -> (64, 1, 1) # -------------------- nn.Conv2d(64 * s, 64 * s, 4, stride=1), nn.BatchNorm2d(64 * s), - nn.LeakyReLU(), + activation, nn.Conv2d(64 * s, 64 * s, 1, stride=1), + nn.Flatten(), + ) + + def forward(self, h: torch.Tensor) -> tuple[torch.Tensor]: + return self.layers(h) + + +class WAELinearHead(nn.Module): + def __init__( + self, + activation_type: WAEActivationType, + s: int, + ) -> None: + super().__init__() + + match activation_type: + case WAEActivationType.LEAKY_RELU: + activation: nn.Module = nn.LeakyReLU() + case WAEActivationType.TANH: + activation = nn.Tanh() + case _: + raise ValueError(f"unknown activation type: {activation_type}") + + self.layers = nn.Sequential( # -------------------- - # normalize + # shape: (64, 4, 4) -> (1024,) # -------------------- nn.Flatten(), - L2Normalize(), + # -------------------- + # shape: (1024,) -> (64,) + # -------------------- + nn.Linear(1024 * s, 256 * s), + nn.BatchNorm1d(256 * s), + activation, + nn.Linear(256 * s, 64 * s), ) - def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - return self.layers(x) + def forward(self, h: torch.Tensor) -> tuple[torch.Tensor]: + return self.layers(h) + + +class WAEAttentionHead(nn.Module): + def __init__( + self, + n_head: int, + s: int, + ) -> None: + super().__init__() + + self.attention = nn.MultiheadAttention( + embed_dim=64 * s, + num_heads=n_head, + batch_first=True, + ) + + def forward(self, h: torch.Tensor) -> tuple[torch.Tensor]: + batch_size, d_s, height, width = h.shape + h = h.permute(0, 2, 3, 1).reshape(batch_size, height * width, d_s) + h, _ = self.attention(h, h, h) + z = h.mean(dim=1) # shape: (batch_size, d * s) + return z class LightUNet(nn.Module): diff --git a/tests/test__network.py b/tests/test__network.py index 797876e..ae94f15 100644 --- a/tests/test__network.py +++ b/tests/test__network.py @@ -2,7 +2,17 @@ import torch -from torch_wae.network import LightUNet, Preprocess, WAENet +from torch_wae.network import ( + Encoder, + LightUNet, + Preprocess, + WAEActivationType, + WAEAttentionHead, + WAEConvHead, + WAEHeadType, + WAELinearHead, + WAENet, +) def test__preprocess_shape() -> None: @@ -25,16 +35,65 @@ def test__light_unet_shape() -> None: def test__wae_encoder_shape() -> None: for s in (1, 2): d = 64 - f = WAENet(s=s).train(False) + f = Encoder(s=s).train(False) x = torch.randn((1, 1, d, d)) - z = f.encoder(x) + z = f(x) + assert z.shape == (1, d * s, 4, 4) + + +def test__wae_conv_head_shape() -> None: + for s in (1, 2): + d = 64 + f = WAEConvHead( + activation_type=WAEActivationType.LEAKY_RELU, + s=s, + ).train(False) + x = torch.randn((1, d * s, 4, 4)) + z = f(x) assert z.shape == (1, d * s) -def test__wae_forward_shape() -> None: +def test__wae_linear_head_shape() -> None: + for s in (1, 2): + d = 64 + f = WAELinearHead( + activation_type=WAEActivationType.LEAKY_RELU, + s=s, + ).train(False) + x = torch.randn((1, d * s, 4, 4)) + z = f(x) + assert z.shape == (1, d * s) + + +def test__wae_attention_head_shape() -> None: for s in (1, 2): d = 64 - f = WAENet(s=s).train(False) - waveform = torch.randn((1, f.preprocess.SAMPLE_RATE)) - z = f(waveform) + f = WAEAttentionHead( + n_head=2, + s=s, + ).train(False) + x = torch.randn((1, d * s, 4, 4)) + z = f(x) assert z.shape == (1, d * s) + + +def test__wae_forward_shape() -> None: + seq_head = ( + WAEHeadType.CONV, + WAEHeadType.LINEAR, + WAEHeadType.ATTEN_1, + WAEHeadType.ATTEN_2, + ) + seq_s = (1, 2) + + for head_type in seq_head: + for s in seq_s: + d = 64 + f = WAENet( + head_type=head_type, + head_activation_type=WAEActivationType.LEAKY_RELU, + s=s, + ).train(False) + waveform = torch.randn((1, f.preprocess.SAMPLE_RATE)) + z = f(waveform) + assert z.shape == (1, d * s)