Skip to content

Commit

Permalink
Convert metadata and mel-spec separately.
Browse files Browse the repository at this point in the history
  • Loading branch information
mitsuse committed Nov 9, 2024
1 parent fe89d55 commit 871afcd
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 124 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ license = {file = "LICENSE"}
convert_dataset_to_jsonl = "torch_wae.cli.convert_dataset_to_jsonl:app"
convert_dataset_to_web = "torch_wae.cli.convert_dataset_to_web:app"
convert_pair_to_web = "torch_wae.cli.convert_pair_to_web:app"
convert_to_melspec = "torch_wae.cli.convert_to_melspec:app"
copy_dataset = "torch_wae.cli.copy_dataset:app"
export_to_onnx = "torch_wae.cli.export_to_onnx:app"
generate_pair = "torch_wae.cli.generate_pair:app"
Expand Down
70 changes: 12 additions & 58 deletions src/torch_wae/cli/convert_dataset_to_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
from pathlib import Path
from typing import Any

import torch
import torchaudio
import typer
from torchaudio import functional as FA

app = typer.Typer()

Expand All @@ -26,12 +23,12 @@ def main(
help="the path of a dataset to be encoded.",
),
size_shard: int = typer.Option(
100,
10,
help="the max size of a shard (unit: MB)",
),
output: Path = typer.Option(
...,
help="the output path of a directory which stores shards of WebDataset,",
help="the output path of a directory which stores shards of WebDataset.",
),
) -> None:
from torch.utils.data import DataLoader
Expand All @@ -47,14 +44,7 @@ def main(
n += 1

loader = DataLoader(
transformed_iter(
dataset,
Transform(
root=root,
resample_rate=16000,
durations=1,
),
),
transformed_iter(dataset, Transform()),
batch_size=1,
shuffle=False,
num_workers=n_workers,
Expand All @@ -66,23 +56,17 @@ def main(

with tqdm(total=n) as progress:
with ShardWriter(pattern, maxsize=max_size, verbose=0) as w:
for i, (path, key, class_id, melspec, ignore) in enumerate(loader):
for i, (path, key, class_id) in enumerate(loader):
if n <= i:
break

ignore = bool(ignore.detach().numpy()[0])
if ignore:
continue

path = path[0]
key = key[0]
class_id = int(class_id.detach().numpy()[0])
melspec = melspec.detach().numpy()[0]

w.write(
{
"__key__": key,
"npy": melspec,
"json": {
"class_id": class_id,
"path": path,
Expand All @@ -94,49 +78,19 @@ def main(


class Transform:
def __init__(
self,
root: Path,
resample_rate: int,
durations: int,
) -> None:
from torch_wae.network import Preprocess

def __init__(self) -> None:
super().__init__()

self.__root = root
self.__resample_rate = resample_rate
self.__durations = durations
self.__preprocess = Preprocess()

def __call__(self, example: Any) -> tuple[str, str, int, torch.Tensor, bool]:
from torch_wae.audio import crop_or_pad_last
def __call__(self, example: Any) -> tuple[str, str, int]:
from torch_wae import fs

root = self.__root
resample_rate = self.__resample_rate
durations = self.__durations
preprocess = self.__preprocess

path = str(example["path"])
key = path.rsplit(".", maxsplit=1)[0]
path = Path(example["path"])
basename = fs.basename(path)
path_npy = str(path.parent / f"{basename}.npy")
key = str(path.parent / basename)
class_id = int(example["class_id"])

path_file = root / path

waveform, sample_rate = torchaudio.load(path_file)
waveform = torch.mean(waveform, dim=0).unsqueeze(0)

if resample_rate != sample_rate:
waveform = FA.resample(waveform, sample_rate, resample_rate)

frames = waveform.shape[-1]

waveform = crop_or_pad_last(resample_rate, durations, waveform)
melspec = preprocess(waveform)[0]

ignore = frames > durations * resample_rate

return path, key, class_id, melspec, ignore
return path_npy, key, class_id


if __name__ == "__main__":
Expand Down
77 changes: 11 additions & 66 deletions src/torch_wae/cli/convert_pair_to_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
from pathlib import Path
from typing import Any

import torch
import torchaudio
import typer
from torchaudio import functional as FA

app = typer.Typer()

Expand All @@ -26,12 +23,12 @@ def main(
help="the path of a dataset to be encoded.",
),
size_shard: int = typer.Option(
200,
20,
help="the max size of a shard (unit: MB)",
),
output: Path = typer.Option(
...,
help="the output path of a directory which stores shards of WebDataset,",
help="the output path of a directory which stores shards of WebDataset.",
),
) -> None:
from torch.utils.data import DataLoader
Expand All @@ -49,11 +46,7 @@ def main(
loader = DataLoader(
transformed_iter(
dataset,
Transform(
root=root,
resample_rate=16000,
durations=1,
),
Transform(),
),
batch_size=1,
shuffle=False,
Expand All @@ -66,26 +59,18 @@ def main(

with tqdm(total=n) as progress:
with ShardWriter(pattern, maxsize=max_size, verbose=0) as w:
for i, (anchor, positive, key, class_id, melspec, ignore) in enumerate(
loader
):
for i, (anchor, positive, key, class_id) in enumerate(loader):
if n <= i:
break

ignore = bool(ignore.detach().numpy()[0])
if ignore:
continue

anchor = anchor[0]
positive = positive[0]
key = key[0]
class_id = int(class_id.detach().numpy()[0])
melspec = melspec.detach().numpy()[0]

w.write(
{
"__key__": key,
"npy": melspec,
"json": {
"class_id": class_id,
"anchor": anchor,
Expand All @@ -98,25 +83,10 @@ def main(


class Transform:
def __init__(
self,
root: Path,
resample_rate: int,
durations: int,
) -> None:
from torch_wae.network import Preprocess

super().__init__()

self.__root = root
self.__resample_rate = resample_rate
self.__durations = durations
self.__preprocess = Preprocess()

def __call__(self, example: Any) -> tuple[str, str, str, int, torch.Tensor, bool]:
from torch_wae import fs
def __init__(self) -> None: ...

root = self.__root
def __call__(self, example: Any) -> tuple[str, str, str, int]:
from torch_wae import fs

anchor = Path(example["anchor"])
positive = Path(example["positive"])
Expand All @@ -125,38 +95,13 @@ def __call__(self, example: Any) -> tuple[str, str, str, int, torch.Tensor, bool
basename_anchor = fs.basename(anchor)
basename_positive = fs.basename(positive)

path_anchor = anchor.parent / f"{basename_anchor}.npy"
path_positive = positive.parent / f"{basename_positive}.npy"

key = str(path / f"{basename_anchor}-{basename_positive}")
class_id = int(example["class_id"])

mel_anchor, ignore_anchor = self.convert_to_melspec(root / anchor)
mel_positive, ignore_positive = self.convert_to_melspec(root / positive)
melspec = torch.stack((mel_anchor, mel_positive))

ignore = ignore_anchor or ignore_positive

return str(anchor), str(positive), key, class_id, melspec, ignore

def convert_to_melspec(self, path: Path) -> tuple[torch.Tensor, bool]:
from torch_wae.audio import crop_or_pad_last

resample_rate = self.__resample_rate
durations = self.__durations
preprocess = self.__preprocess

waveform, sample_rate = torchaudio.load(path)
waveform = torch.mean(waveform, dim=0).unsqueeze(0)

if resample_rate != sample_rate:
waveform = FA.resample(waveform, sample_rate, resample_rate)

frames = waveform.shape[-1]

waveform = crop_or_pad_last(resample_rate, durations, waveform)
melspec = preprocess(waveform)[0]

ignore = frames > durations * resample_rate

return melspec, ignore
return str(path_anchor), str(path_positive), key, class_id


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 871afcd

Please sign in to comment.