diff --git a/pyproject.toml b/pyproject.toml index c6e8761..4a7a96e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ license = {file = "LICENSE"} [project.scripts] convert_dataset_to_jsonl = "torch_wae.cli.convert_dataset_to_jsonl:app" convert_dataset_to_web = "torch_wae.cli.convert_dataset_to_web:app" +convert_pair_to_web = "torch_wae.cli.convert_pair_to_web:app" copy_dataset = "torch_wae.cli.copy_dataset:app" export_to_onnx = "torch_wae.cli.export_to_onnx:app" generate_pair = "torch_wae.cli.generate_pair:app" diff --git a/src/torch_wae/cli/convert_pair_to_web.py b/src/torch_wae/cli/convert_pair_to_web.py new file mode 100644 index 0000000..0346c53 --- /dev/null +++ b/src/torch_wae/cli/convert_pair_to_web.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import torch +import torchaudio +import typer +from torchaudio import functional as FA + +app = typer.Typer() + + +@app.command() +def main( + n_workers: int = typer.Option( + ..., + help="the number of workers used to transform dataset", + ), + root: Path = typer.Option( + ..., + help="the root path of a directory which contains datasets.", + ), + annotation: Path = typer.Option( + ..., + help="the path of a dataset to be encoded.", + ), + size_shard: int = typer.Option( + 200, + help="the max size of a shard (unit: MB)", + ), + output: Path = typer.Option( + ..., + help="the output path of a directory which stores shards of WebDataset,", + ), +) -> None: + from torch.utils.data import DataLoader + from tqdm import tqdm + from webdataset.writer import ShardWriter + + from torch_wae.dataset import JsonLinesDataset, transformed_iter + + dataset = JsonLinesDataset(annotation) + + n = 0 + for _ in dataset: + n += 1 + + loader = DataLoader( + transformed_iter( + dataset, + Transform( + root=root, + resample_rate=16000, + durations=1, + ), + ), + batch_size=1, + shuffle=False, + num_workers=n_workers, + ) + + output.mkdir(parents=True, exist_ok=True) + pattern = str(output / "%04d.tar") + max_size = size_shard * 1024**2 + + with tqdm(total=n) as progress: + with ShardWriter(pattern, maxsize=max_size, verbose=0) as w: + for i, (anchor, positive, key, class_id, melspec, ignore) in enumerate( + loader + ): + if n <= i: + break + + ignore = bool(ignore.detach().numpy()[0]) + if ignore: + continue + + anchor = anchor[0] + positive = positive[0] + key = key[0] + class_id = int(class_id.detach().numpy()[0]) + melspec = melspec.detach().numpy()[0] + + w.write( + { + "__key__": key, + "npy": melspec, + "json": { + "class_id": class_id, + "anchor": anchor, + "positive": positive, + }, + } + ) + + progress.update(1) + + +class Transform: + def __init__( + self, + root: Path, + resample_rate: int, + durations: int, + ) -> None: + from torch_wae.network import Preprocess + + super().__init__() + + self.__root = root + self.__resample_rate = resample_rate + self.__durations = durations + self.__preprocess = Preprocess() + + def __call__(self, example: Any) -> tuple[str, str, str, int, torch.Tensor, bool]: + from torch_wae import fs + + root = self.__root + + anchor = Path(example["anchor"]) + positive = Path(example["positive"]) + + path = anchor.parent + basename_anchor = fs.basename(anchor) + basename_positive = fs.basename(positive) + + key = str(path / f"{basename_anchor}-{basename_positive}") + class_id = int(example["class_id"]) + + mel_anchor, ignore_anchor = self.convert_to_melspec(root / anchor) + mel_positive, ignore_positive = self.convert_to_melspec(root / positive) + melspec = torch.stack((mel_anchor, mel_positive)) + + ignore = ignore_anchor or ignore_positive + + return str(anchor), str(positive), key, class_id, melspec, ignore + + def convert_to_melspec(self, path: Path) -> tuple[torch.Tensor, bool]: + from torch_wae.audio import crop_or_pad_last + + resample_rate = self.__resample_rate + durations = self.__durations + preprocess = self.__preprocess + + waveform, sample_rate = torchaudio.load(path) + waveform = torch.mean(waveform, dim=0).unsqueeze(0) + + if resample_rate != sample_rate: + waveform = FA.resample(waveform, sample_rate, resample_rate) + + frames = waveform.shape[-1] + + waveform = crop_or_pad_last(resample_rate, durations, waveform) + melspec = preprocess(waveform)[0] + + ignore = frames > durations * resample_rate + + return melspec, ignore + + +if __name__ == "__main__": + app()