Skip to content

Commit

Permalink
Move WithResample layer to torch_wae.network.
Browse files Browse the repository at this point in the history
  • Loading branch information
mitsuse committed Sep 15, 2024
1 parent 42ea029 commit 85921ed
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
15 changes: 1 addition & 14 deletions src/torch_wae/cli/export_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import typer
from torchaudio import functional as F

from torch_wae.network import WAENet
from torch_wae.network import WAENet, WithResample

app = typer.Typer()

Expand Down Expand Up @@ -51,18 +51,5 @@ def main(
)


class WithResample(torch.nn.Module):
def __init__(self, f: WAENet, sample_rate: int) -> None:
super().__init__()

self.f = f
self.sample_rate = sample_rate

def forward(self, waveform: torch.Tensor) -> torch.Tensor:
h = F.resample(waveform, self.sample_rate, self.f.SAMPLE_RATE)
z = self.f(h)
return z


if __name__ == "__main__":
app()
13 changes: 13 additions & 0 deletions src/torch_wae/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,16 @@ def __init__(self, dim: int = 1, eps: float = 1e-12):

def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.normalize(x, p=2, dim=self.dim, eps=self.eps)


class WithResample(torch.nn.Module):
def __init__(self, f: WAENet, sample_rate: int) -> None:
super().__init__()

self.f = f
self.sample_rate = sample_rate

def forward(self, waveform: torch.Tensor) -> torch.Tensor:
h = F.resample(waveform, self.sample_rate, self.f.SAMPLE_RATE)
z = self.f(h)
return z

0 comments on commit 85921ed

Please sign in to comment.