Skip to content

Commit

Permalink
Added various head types: 'conv', 'linear', 'atten_1', 'atten_2', 'at…
Browse files Browse the repository at this point in the history
…ten_4'
  • Loading branch information
mitsuse committed Dec 4, 2024
1 parent 6d9aaf3 commit 622c3da
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 18 deletions.
18 changes: 15 additions & 3 deletions src/torch_wae/cli/export_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import typer

from torch_wae.network import WAENet, WithResample
from torch_wae.network import WAEActivationType, WAEHeadType, WAENet, WithResample

app = typer.Typer()

Expand All @@ -20,12 +20,24 @@ def main(
48000,
help="the original sample-rate for input (NOTE: WAENet resamples audio to 16KHz)",
),
head_type: WAEHeadType = typer.Option(
...,
help="",
),
head_activation_type: WAEActivationType = typer.Option(
...,
help="",
),
output: Path = typer.Option(
...,
help="the output path of a model converted for ONNX",
),
) -> None:
f = WAENet(s=1)
f = WAENet(
head_type=head_type,
head_activation_type=head_activation_type,
s=1,
)
assert sample_rate >= f.preprocess.SAMPLE_RATE

f.load_state_dict(torch.load(pt))
Expand All @@ -38,7 +50,7 @@ def main(
output.parent.mkdir(parents=True, exist_ok=True)
torch.onnx.export(
m,
waveform,
(waveform,),
str(output),
export_params=True,
input_names=["waveform"],
Expand Down
150 changes: 142 additions & 8 deletions src/torch_wae/network.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,83 @@
from __future__ import annotations

import logging
from enum import Enum

import torch
from convmelspec.stft import ConvertibleSpectrogram as Spectrogram
from torch import nn
from torch.nn import functional as F
from torchaudio import functional as FA


class WAEHeadType(str, Enum):
    """Names the head module that `WAENet` places after its encoder.

    The `str` mixin makes every member compare equal to its plain string
    value, so members can be built directly from text such as ``"conv"``.
    """

    # Convolutional head (`WAEConvHead`).
    CONV = "conv"

    # Fully-connected head (`WAELinearHead`).
    LINEAR = "linear"

    # Attention heads (`WAEAttentionHead`); the numeric suffix is the
    # number of attention heads passed as `n_head`.
    ATTEN_1 = "atten_1"
    ATTEN_2 = "atten_2"
    ATTEN_4 = "atten_4"


class WAEActivationType(str, Enum):
    """Names the non-linearity used inside the conv and linear heads.

    The `str` mixin makes every member compare equal to its plain string
    value.
    """

    # Selects `nn.LeakyReLU()`.
    LEAKY_RELU = "leaky_relu"

    # Selects `nn.Tanh()`.
    TANH = "tanh"


# Word Audio Encoder (WAE) - a network for audio, similar to MobileNet V2 for images.
class WAENet(nn.Module):
def __init__(self, s: int) -> None:
def __init__(
self,
s: int,
head_type: WAEHeadType,
head_activation_type: WAEActivationType,
) -> None:
super().__init__()

self.preprocess = Preprocess()

self.encoder = Encoder(s=s)

if head_type not in (WAEHeadType.CONV, WAEHeadType.LINEAR):
logging.debug(
"`head_activation_type` is not supported with specified `head_type`"
)

match head_type:
case WAEHeadType.CONV:
self.head: nn.Module = WAEConvHead(
activation_type=head_activation_type,
s=s,
)
case WAEHeadType.LINEAR:
self.head = WAELinearHead(
activation_type=head_activation_type,
s=s,
)
case WAEHeadType.ATTEN_1:
self.head = WAEAttentionHead(
n_head=1,
s=s,
)
case WAEHeadType.ATTEN_2:
self.head = WAEAttentionHead(
n_head=2,
s=s,
)
case WAEHeadType.ATTEN_4:
self.head = WAEAttentionHead(
n_head=4,
s=s,
)
case _:
raise ValueError(f"unknown head type: {head_type}")

self.norm = L2Normalize()

def forward(self, waveform: torch.Tensor) -> torch.Tensor:
x = self.preprocess(waveform)
z = self.encoder(x)
h = self.preprocess(waveform)
h = self.encoder(h)
h = self.head(h)
z = self.norm(h)
return z


Expand Down Expand Up @@ -60,22 +119,97 @@ def __init__(self, s: int) -> None:
# --------------------
InvertedBottleneck(k=3, c_in=32 * s, c_out=64 * s, stride=1),
InvertedBottleneck(k=3, c_in=64 * s, c_out=64 * s, stride=1),
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply ``self.layers`` to ``x`` and return the result unchanged."""
    return self.layers(x)


class WAEConvHead(nn.Module):
def __init__(
self,
activation_type: WAEActivationType,
s: int,
) -> None:
super().__init__()

match activation_type:
case WAEActivationType.LEAKY_RELU:
activation: nn.Module = nn.LeakyReLU()
case WAEActivationType.TANH:
activation = nn.Tanh()
case _:
raise ValueError(f"unknown activation type: {activation_type}")

self.layers = nn.Sequential(
# --------------------
# shape: (64, 4, 4) -> (64, 1, 1)
# --------------------
nn.Conv2d(64 * s, 64 * s, 4, stride=1),
nn.BatchNorm2d(64 * s),
nn.LeakyReLU(),
activation,
nn.Conv2d(64 * s, 64 * s, 1, stride=1),
nn.Flatten(),
)

def forward(self, h: torch.Tensor) -> tuple[torch.Tensor]:
return self.layers(h)


class WAELinearHead(nn.Module):
def __init__(
self,
activation_type: WAEActivationType,
s: int,
) -> None:
super().__init__()

match activation_type:
case WAEActivationType.LEAKY_RELU:
activation: nn.Module = nn.LeakyReLU()
case WAEActivationType.TANH:
activation = nn.Tanh()
case _:
raise ValueError(f"unknown activation type: {activation_type}")

self.layers = nn.Sequential(
# --------------------
# normalize
# shape: (64, 4, 4) -> (1024,)
# --------------------
nn.Flatten(),
L2Normalize(),
# --------------------
# shape: (1024,) -> (64,)
# --------------------
nn.Linear(1024 * s, 256 * s),
nn.BatchNorm1d(256 * s),
activation,
nn.Linear(256 * s, 64 * s),
)

def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return self.layers(x)
def forward(self, h: torch.Tensor) -> tuple[torch.Tensor]:
return self.layers(h)


class WAEAttentionHead(nn.Module):
    """Self-attention head that pools spatial positions into one vector.

    Args:
        n_head: Number of attention heads given to `nn.MultiheadAttention`.
        s: Width multiplier; the embedding dimension is ``64 * s``.
    """

    def __init__(
        self,
        n_head: int,
        s: int,
    ) -> None:
        super().__init__()

        self.attention = nn.MultiheadAttention(
            embed_dim=64 * s,
            num_heads=n_head,
            batch_first=True,
        )

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Attend over the spatial positions of ``h`` and average them.

        Args:
            h: Feature map of shape ``(batch, channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, channels)``.
        """
        batch_size, channels, height, width = h.shape
        # (B, C, H, W) -> (B, H * W, C): every spatial position is one token.
        tokens = h.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)
        # Self-attention; the returned attention weights are not needed.
        attended, _ = self.attention(tokens, tokens, tokens)
        # Mean over the token axis gives one vector per batch item.
        return attended.mean(dim=1)


class LightUNet(nn.Module):
Expand Down
73 changes: 66 additions & 7 deletions tests/test__network.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@

import torch

from torch_wae.network import LightUNet, Preprocess, WAENet
from torch_wae.network import (
Encoder,
LightUNet,
Preprocess,
WAEActivationType,
WAEAttentionHead,
WAEConvHead,
WAEHeadType,
WAELinearHead,
WAENet,
)


def test__preprocess_shape() -> None:
Expand All @@ -25,16 +35,65 @@ def test__light_unet_shape() -> None:
def test__wae_encoder_shape() -> None:
    """`Encoder` maps a (1, 1, 64, 64) input to (1, 64 * s, 4, 4)."""
    for s in (1, 2):
        d = 64
        # Exercise the encoder directly rather than through `WAENet`, which
        # now needs a head type and activation type to be constructed.
        f = Encoder(s=s).train(False)
        x = torch.randn((1, 1, d, d))
        z = f(x)
        assert z.shape == (1, d * s, 4, 4)


def test__wae_conv_head_shape() -> None:
    """`WAEConvHead` maps (1, 64 * s, 4, 4) features to (1, 64 * s)."""
    base_dim = 64
    for scale in (1, 2):
        width = base_dim * scale
        head = WAEConvHead(
            activation_type=WAEActivationType.LEAKY_RELU,
            s=scale,
        )
        # Inference mode so BatchNorm accepts a single-item batch.
        head = head.train(False)
        features = torch.randn((1, width, 4, 4))
        embedding = head(features)
        assert embedding.shape == (1, width)


def test__wae_linear_head_shape() -> None:
    """`WAELinearHead` maps (1, 64 * s, 4, 4) features to (1, 64 * s)."""
    for s in (1, 2):
        d = 64
        # Inference mode so BatchNorm1d accepts a single-item batch.
        f = WAELinearHead(
            activation_type=WAEActivationType.LEAKY_RELU,
            s=s,
        ).train(False)
        x = torch.randn((1, d * s, 4, 4))
        z = f(x)
        assert z.shape == (1, d * s)


def test__wae_attention_head_shape() -> None:
    """`WAEAttentionHead` maps (1, 64 * s, 4, 4) features to (1, 64 * s)."""
    for s in (1, 2):
        d = 64
        f = WAEAttentionHead(
            n_head=2,
            s=s,
        ).train(False)
        x = torch.randn((1, d * s, 4, 4))
        z = f(x)
        assert z.shape == (1, d * s)


def test__wae_forward_shape() -> None:
    """`WAENet` yields a (1, 64 * s) embedding for every head type."""
    # Iterate the enum itself so that every head type, including ATTEN_4 and
    # any member added later, is covered without editing this list.
    seq_head = tuple(WAEHeadType)
    seq_s = (1, 2)

    for head_type in seq_head:
        for s in seq_s:
            d = 64
            f = WAENet(
                head_type=head_type,
                head_activation_type=WAEActivationType.LEAKY_RELU,
                s=s,
            ).train(False)
            # One second of audio at the network's own sample rate.
            waveform = torch.randn((1, f.preprocess.SAMPLE_RATE))
            z = f(waveform)
            assert z.shape == (1, d * s)

0 comments on commit 622c3da

Please sign in to comment.