From 871afcd003f56079f290a3fc2630e7b5d831f1f7 Mon Sep 17 00:00:00 2001 From: Tomoya Kose Date: Sat, 9 Nov 2024 19:46:55 +0900 Subject: [PATCH] Convert metadata and mel-spec separately. --- pyproject.toml | 1 + src/torch_wae/cli/convert_dataset_to_web.py | 70 ++-------- src/torch_wae/cli/convert_pair_to_web.py | 77 ++--------- src/torch_wae/cli/convert_to_melspec.py | 134 ++++++++++++++++++++ 4 files changed, 158 insertions(+), 124 deletions(-) create mode 100644 src/torch_wae/cli/convert_to_melspec.py diff --git a/pyproject.toml b/pyproject.toml index 4a7a96e..c31d96d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ license = {file = "LICENSE"} convert_dataset_to_jsonl = "torch_wae.cli.convert_dataset_to_jsonl:app" convert_dataset_to_web = "torch_wae.cli.convert_dataset_to_web:app" convert_pair_to_web = "torch_wae.cli.convert_pair_to_web:app" +convert_to_melspec = "torch_wae.cli.convert_to_melspec:app" copy_dataset = "torch_wae.cli.copy_dataset:app" export_to_onnx = "torch_wae.cli.export_to_onnx:app" generate_pair = "torch_wae.cli.generate_pair:app" diff --git a/src/torch_wae/cli/convert_dataset_to_web.py b/src/torch_wae/cli/convert_dataset_to_web.py index 24365b7..a996780 100644 --- a/src/torch_wae/cli/convert_dataset_to_web.py +++ b/src/torch_wae/cli/convert_dataset_to_web.py @@ -3,10 +3,7 @@ from pathlib import Path from typing import Any -import torch -import torchaudio import typer -from torchaudio import functional as FA app = typer.Typer() @@ -26,12 +23,12 @@ def main( help="the path of a dataset to be encoded.", ), size_shard: int = typer.Option( - 100, + 10, help="the max size of a shard (unit: MB)", ), output: Path = typer.Option( ..., - help="the output path of a directory which stores shards of WebDataset,", + help="the output path of a directory which stores shards of WebDataset.", ), ) -> None: from torch.utils.data import DataLoader @@ -47,14 +44,7 @@ def main( n += 1 loader = DataLoader( - transformed_iter( - dataset, - Transform( - root=root, - resample_rate=16000, - durations=1, - ), - ), + transformed_iter(dataset, Transform()), batch_size=1, shuffle=False, num_workers=n_workers, @@ -66,23 +56,17 @@ def main( with tqdm(total=n) as progress: with ShardWriter(pattern, maxsize=max_size, verbose=0) as w: - for i, (path, key, class_id, melspec, ignore) in enumerate(loader): + for i, (path, key, class_id) in enumerate(loader): if n <= i: break - ignore = bool(ignore.detach().numpy()[0]) - if ignore: - continue - path = path[0] key = key[0] class_id = int(class_id.detach().numpy()[0]) - melspec = melspec.detach().numpy()[0] w.write( { "__key__": key, - "npy": melspec, "json": { "class_id": class_id, "path": path, @@ -94,49 +78,19 @@ def main( class Transform: - def __init__( - self, - root: Path, - resample_rate: int, - durations: int, - ) -> None: - from torch_wae.network import Preprocess - + def __init__(self) -> None: super().__init__() - self.__root = root - self.__resample_rate = resample_rate - self.__durations = durations - self.__preprocess = Preprocess() - - def __call__(self, example: Any) -> tuple[str, str, int, torch.Tensor, bool]: - from torch_wae.audio import crop_or_pad_last + def __call__(self, example: Any) -> tuple[str, str, int]: + from torch_wae import fs - root = self.__root - resample_rate = self.__resample_rate - durations = self.__durations - preprocess = self.__preprocess - - path = str(example["path"]) - key = path.rsplit(".", maxsplit=1)[0] + path = Path(example["path"]) + basename = fs.basename(path) + path_npy = str(path.parent / f"{basename}.npy") + key = str(path.parent / basename) class_id = int(example["class_id"]) - path_file = root / path - - waveform, sample_rate = torchaudio.load(path_file) - waveform = torch.mean(waveform, dim=0).unsqueeze(0) - - if resample_rate != sample_rate: - waveform = FA.resample(waveform, sample_rate, resample_rate) - - frames = waveform.shape[-1] - - waveform = crop_or_pad_last(resample_rate, durations, waveform) - melspec = preprocess(waveform)[0] - - ignore = frames > durations * resample_rate - - return path, key, class_id, melspec, ignore + return path_npy, key, class_id if __name__ == "__main__": diff --git a/src/torch_wae/cli/convert_pair_to_web.py b/src/torch_wae/cli/convert_pair_to_web.py index 0346c53..7bcd1c0 100644 --- a/src/torch_wae/cli/convert_pair_to_web.py +++ b/src/torch_wae/cli/convert_pair_to_web.py @@ -3,10 +3,7 @@ from pathlib import Path from typing import Any -import torch -import torchaudio import typer -from torchaudio import functional as FA app = typer.Typer() @@ -26,12 +23,12 @@ def main( help="the path of a dataset to be encoded.", ), size_shard: int = typer.Option( - 200, + 20, help="the max size of a shard (unit: MB)", ), output: Path = typer.Option( ..., - help="the output path of a directory which stores shards of WebDataset,", + help="the output path of a directory which stores shards of WebDataset.", ), ) -> None: from torch.utils.data import DataLoader @@ -49,11 +46,7 @@ def main( loader = DataLoader( transformed_iter( dataset, - Transform( - root=root, - resample_rate=16000, - durations=1, - ), + Transform(), ), batch_size=1, shuffle=False, @@ -66,26 +59,18 @@ def main( with tqdm(total=n) as progress: with ShardWriter(pattern, maxsize=max_size, verbose=0) as w: - for i, (anchor, positive, key, class_id, melspec, ignore) in enumerate( - loader - ): + for i, (anchor, positive, key, class_id) in enumerate(loader): if n <= i: break - ignore = bool(ignore.detach().numpy()[0]) - if ignore: - continue - anchor = anchor[0] positive = positive[0] key = key[0] class_id = int(class_id.detach().numpy()[0]) - melspec = melspec.detach().numpy()[0] w.write( { "__key__": key, - "npy": melspec, "json": { "class_id": class_id, "anchor": anchor, @@ -98,25 +83,10 @@ def main( class Transform: - def __init__( - self, - root: Path, - resample_rate: int, - durations: int, - ) -> None: - from torch_wae.network import Preprocess - - super().__init__() - - self.__root = root - self.__resample_rate = resample_rate - self.__durations = durations - self.__preprocess = Preprocess() - - def __call__(self, example: Any) -> tuple[str, str, str, int, torch.Tensor, bool]: - from torch_wae import fs + def __init__(self) -> None: ... - root = self.__root + def __call__(self, example: Any) -> tuple[str, str, str, int]: + from torch_wae import fs anchor = Path(example["anchor"]) positive = Path(example["positive"]) @@ -125,38 +95,13 @@ def __call__(self, example: Any) -> tuple[str, str, str, int, torch.Tensor, bool basename_anchor = fs.basename(anchor) basename_positive = fs.basename(positive) + path_anchor = anchor.parent / f"{basename_anchor}.npy" + path_positive = positive.parent / f"{basename_positive}.npy" + key = str(path / f"{basename_anchor}-{basename_positive}") class_id = int(example["class_id"]) - mel_anchor, ignore_anchor = self.convert_to_melspec(root / anchor) - mel_positive, ignore_positive = self.convert_to_melspec(root / positive) - melspec = torch.stack((mel_anchor, mel_positive)) - - ignore = ignore_anchor or ignore_positive - - return str(anchor), str(positive), key, class_id, melspec, ignore - - def convert_to_melspec(self, path: Path) -> tuple[torch.Tensor, bool]: - from torch_wae.audio import crop_or_pad_last - - resample_rate = self.__resample_rate - durations = self.__durations - preprocess = self.__preprocess - - waveform, sample_rate = torchaudio.load(path) - waveform = torch.mean(waveform, dim=0).unsqueeze(0) - - if resample_rate != sample_rate: - waveform = FA.resample(waveform, sample_rate, resample_rate) - - frames = waveform.shape[-1] - - waveform = crop_or_pad_last(resample_rate, durations, waveform) - melspec = preprocess(waveform)[0] - - ignore = frames > durations * resample_rate - - return melspec, ignore + return str(path_anchor), str(path_positive), key, class_id if __name__ == "__main__": diff --git a/src/torch_wae/cli/convert_to_melspec.py b/src/torch_wae/cli/convert_to_melspec.py new file mode 100644 index 0000000..8ff1f92 --- /dev/null +++ b/src/torch_wae/cli/convert_to_melspec.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import numpy as np +import torch +import torchaudio +import typer +from torchaudio import functional as FA + +app = typer.Typer() + + +@app.command() +def main( + n_workers: int = typer.Option( + ..., + help="the number of workers used to transform dataset", + ), + root: Path = typer.Option( + ..., + help="the root path of a directory which contains datasets.", + ), + annotation: Path = typer.Option( + ..., + help="the path of a dataset to be encoded.", + ), + output: Path = typer.Option( + ..., + help="the output path of a directory which stores mel spectrograms.", + ), +) -> None: + from torch.utils.data import DataLoader + from tqdm import tqdm + + from torch_wae.dataset import JsonLinesDataset, transformed_iter + + dataset = JsonLinesDataset(annotation) + + n = 0 + for _ in dataset: + n += 1 + + loader = DataLoader( + transformed_iter( + dataset, + Transform( + root=root, + resample_rate=16000, + durations=1, + ), + ), + batch_size=1, + shuffle=False, + num_workers=n_workers, + ) + + output.mkdir(parents=True, exist_ok=True) + + with tqdm(total=n) as progress: + for i, (path, key, class_id, melspec, ignore) in enumerate(loader): + if n <= i: + break + + ignore = bool(ignore.detach().numpy()[0]) + if ignore: + continue + + path = path[0] + key = key[0] + class_id = int(class_id.detach().numpy()[0]) + melspec = melspec.detach().numpy()[0] + + path_melspec = output / path + path_melspec.parent.mkdir(parents=True, exist_ok=True) + + np.save(str(path_melspec), melspec) + + progress.update(1) + + +class Transform: + def __init__( + self, + root: Path, + resample_rate: int, + durations: int, + ) -> None: + from torch_wae.network import Preprocess + + super().__init__() + + self.__root = root + self.__resample_rate = resample_rate + self.__durations = durations + self.__preprocess = Preprocess() + + def __call__(self, example: Any) -> tuple[str, str, torch.Tensor, bool]: + from torch_wae import fs + from torch_wae.audio import crop_or_pad_last + + root = self.__root + resample_rate = self.__resample_rate + durations = self.__durations + preprocess = self.__preprocess + + path = Path(example["path"]) + key = str(path).rsplit(".", maxsplit=1)[0] + class_id = int(example["class_id"]) + + path_file = root / path + + waveform, sample_rate = torchaudio.load(path_file) + waveform = torch.mean(waveform, dim=0).unsqueeze(0) + + if resample_rate != sample_rate: + waveform = FA.resample(waveform, sample_rate, resample_rate) + + frames = waveform.shape[-1] + + waveform = crop_or_pad_last(resample_rate, durations, waveform) + melspec = preprocess(waveform)[0] + + ignore = frames > durations * resample_rate + + basename_melspec = fs.basename(path) + path_melspec = str(path.parent / f"{basename_melspec}.npy") + + return path_melspec, key, class_id, melspec, ignore + + +if __name__ == "__main__": + app()