Commit

Merge branch 'convert-pair-to-melspecs'

mitsuse committed Nov 8, 2024
2 parents 18aa522 + c2d687a commit fe89d55
Showing 2 changed files with 164 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -22,6 +22,7 @@ license = {file = "LICENSE"}
[project.scripts]
convert_dataset_to_jsonl = "torch_wae.cli.convert_dataset_to_jsonl:app"
convert_dataset_to_web = "torch_wae.cli.convert_dataset_to_web:app"
convert_pair_to_web = "torch_wae.cli.convert_pair_to_web:app"
copy_dataset = "torch_wae.cli.copy_dataset:app"
export_to_onnx = "torch_wae.cli.export_to_onnx:app"
generate_pair = "torch_wae.cli.generate_pair:app"
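With the [project.scripts] entry above, the new command can be exercised end to end. The following is a minimal sketch using Typer's bundled test runner; the worker count and paths are placeholders, and the annotation file is assumed to be JSON Lines whose records carry "anchor", "positive", and "class_id" fields, matching what Transform.__call__ reads in the new module below.

# Hedged sketch: drive the new command in-process via typer.testing.CliRunner.
# The option values are illustrative, not taken from this commit.
from typer.testing import CliRunner

from torch_wae.cli.convert_pair_to_web import app

runner = CliRunner()
result = runner.invoke(
    app,
    [
        "--n-workers", "2",
        "--root", "data/audio",
        "--annotation", "data/pairs.jsonl",
        "--size-shard", "200",
        "--output", "data/shards",
    ],
)
assert result.exit_code == 0, result.output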
163 changes: 163 additions & 0 deletions src/torch_wae/cli/convert_pair_to_web.py
@@ -0,0 +1,163 @@
from __future__ import annotations

from pathlib import Path
from typing import Any

import torch
import torchaudio
import typer
from torchaudio import functional as FA

app = typer.Typer()


@app.command()
def main(
    n_workers: int = typer.Option(
        ...,
        help="the number of workers used to transform the dataset",
    ),
    root: Path = typer.Option(
        ...,
        help="the root path of a directory which contains datasets.",
    ),
    annotation: Path = typer.Option(
        ...,
        help="the path of a dataset to be encoded.",
    ),
    size_shard: int = typer.Option(
        200,
        help="the max size of a shard (unit: MB)",
    ),
    output: Path = typer.Option(
        ...,
        help="the output path of a directory which stores shards of WebDataset.",
    ),
) -> None:
    from torch.utils.data import DataLoader
    from tqdm import tqdm
    from webdataset.writer import ShardWriter

    from torch_wae.dataset import JsonLinesDataset, transformed_iter

    dataset = JsonLinesDataset(annotation)

    # Count the examples up front so the progress bar has a known total.
    n = 0
    for _ in dataset:
        n += 1

    loader = DataLoader(
        transformed_iter(
            dataset,
            Transform(
                root=root,
                resample_rate=16000,
                durations=1,
            ),
        ),
        batch_size=1,
        shuffle=False,
        num_workers=n_workers,
    )

    output.mkdir(parents=True, exist_ok=True)
    pattern = str(output / "%04d.tar")
    max_size = size_shard * 1024**2

    with tqdm(total=n) as progress:
        with ShardWriter(pattern, maxsize=max_size, verbose=0) as w:
            for i, (anchor, positive, key, class_id, melspec, ignore) in enumerate(
                loader
            ):
                if n <= i:
                    break

                # Skip pairs flagged by the transform (audio longer than the target duration).
                ignore = bool(ignore.detach().numpy()[0])
                if ignore:
                    continue

                # Unwrap the batch dimension added by the DataLoader (batch_size=1).
                anchor = anchor[0]
                positive = positive[0]
                key = key[0]
                class_id = int(class_id.detach().numpy()[0])
                melspec = melspec.detach().numpy()[0]

                w.write(
                    {
                        "__key__": key,
                        "npy": melspec,
                        "json": {
                            "class_id": class_id,
                            "anchor": anchor,
                            "positive": positive,
                        },
                    }
                )

                progress.update(1)


class Transform:
    def __init__(
        self,
        root: Path,
        resample_rate: int,
        durations: int,
    ) -> None:
        from torch_wae.network import Preprocess

        super().__init__()

        self.__root = root
        self.__resample_rate = resample_rate
        self.__durations = durations
        self.__preprocess = Preprocess()

    def __call__(self, example: Any) -> tuple[str, str, str, int, torch.Tensor, bool]:
        from torch_wae import fs

        root = self.__root

        anchor = Path(example["anchor"])
        positive = Path(example["positive"])

        path = anchor.parent
        basename_anchor = fs.basename(anchor)
        basename_positive = fs.basename(positive)

        key = str(path / f"{basename_anchor}-{basename_positive}")
        class_id = int(example["class_id"])

        # Convert both sides of the pair and stack them into a single tensor.
        mel_anchor, ignore_anchor = self.convert_to_melspec(root / anchor)
        mel_positive, ignore_positive = self.convert_to_melspec(root / positive)
        melspec = torch.stack((mel_anchor, mel_positive))

        ignore = ignore_anchor or ignore_positive

        return str(anchor), str(positive), key, class_id, melspec, ignore

    def convert_to_melspec(self, path: Path) -> tuple[torch.Tensor, bool]:
        from torch_wae.audio import crop_or_pad_last

        resample_rate = self.__resample_rate
        durations = self.__durations
        preprocess = self.__preprocess

        # Mix down to mono, keeping a channel dimension.
        waveform, sample_rate = torchaudio.load(path)
        waveform = torch.mean(waveform, dim=0).unsqueeze(0)

        if resample_rate != sample_rate:
            waveform = FA.resample(waveform, sample_rate, resample_rate)

        frames = waveform.shape[-1]

        waveform = crop_or_pad_last(resample_rate, durations, waveform)
        melspec = preprocess(waveform)[0]

        # Flag clips longer than the target duration; the caller skips them.
        ignore = frames > durations * resample_rate

        return melspec, ignore


if __name__ == "__main__":
    app()
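
Each shard entry written above pairs a stacked (anchor, positive) mel-spectrogram under "npy" with its metadata under "json". A rough sketch of consuming the shards, assuming typical webdataset usage (the shard range is illustrative, and the decode()/to_tuple() pipeline is a webdataset convenience, not part of this commit):

import webdataset as wds

# Assumed read-back pipeline; webdataset's default decoders turn the "npy"
# entry into a numpy array and the "json" entry into a dict.
dataset = (
    wds.WebDataset("data/shards/{0000..0009}.tar")
    .decode()
    .to_tuple("npy", "json")
)

for melspec, meta in dataset:
    # melspec: stacked (anchor, positive) mel-spectrograms.
    # meta: {"class_id": ..., "anchor": ..., "positive": ...}
    ...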
