Skip to content

Commit

Permalink
Convert dataset to WebDataset.
Browse files Browse the repository at this point in the history
  • Loading branch information
mitsuse committed Oct 31, 2024
1 parent c5f4f31 commit 096aa27
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 1 deletion.
28 changes: 27 additions & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@ dependencies = [
"pydantic>=2.8.2",
"onnx>=1.16.1",
"convmelspec @ git+https://github.com/adobe-research/convmelspec@4c797b24175df51431ceb374ee57843e1cb2eaf0",
"webdataset>=0.2.100",
]
requires-python = ">=3.12,<3.13"
readme = "README.md"
license = {file = "LICENSE"}

[project.scripts]
convert_dataset_to_jsonl = "torch_wae.cli.convert_dataset_to_jsonl:app"
convert_dataset_to_web = "torch_wae.cli.convert_dataset_to_web:app"
copy_dataset = "torch_wae.cli.copy_dataset:app"
export_to_onnx = "torch_wae.cli.export_to_onnx:app"
generate_pair = "torch_wae.cli.generate_pair:app"
Expand Down
16 changes: 16 additions & 0 deletions src/torch_wae/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,22 @@ def crop_randomly(
return torch.cat((waveform, p), dim=-1)


def crop_or_pad_last(
sample_rate: int,
durations: int,
waveform: torch.Tensor,
) -> torch.Tensor:
c, d = waveform.shape
size = sample_rate * durations
pad = max(0, size - d)

if d > size:
return waveform[:, :size]
else:
p = torch.zeros((c, pad), dtype=waveform.dtype).to(waveform.device)
return torch.cat((waveform, p), dim=-1)


def gain_randomly(
min_: float,
max_: float,
Expand Down
49 changes: 49 additions & 0 deletions src/torch_wae/cli/convert_dataset_to_jsonl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

from pathlib import Path

import typer

app = typer.Typer()


@app.command()
def main(
annotation: Path = typer.Option(
...,
help="the path of a dataset to be encoded.",
),
output: Path = typer.Option(
...,
help="the output path of a converted JSONL file.,",
),
) -> None:
import json

from tqdm import tqdm

from torch_wae.dataset import ClassificationDataset

dataset = ClassificationDataset(
annotation=annotation,
root=annotation.parent,
)

n = len(dataset)

output.parent.mkdir(parents=True, exist_ok=True)
with tqdm(total=n) as progress, output.open(mode="w") as f:
for i in range(n):
example = dataset[i]
line = json.dumps(
{
"path": example.path,
"class_id": example.class_id,
}
)
f.write(f"{line}\n")
progress.update(1)


if __name__ == "__main__":
app()
99 changes: 99 additions & 0 deletions src/torch_wae/cli/convert_dataset_to_web.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from __future__ import annotations

from pathlib import Path
from typing import Optional

import numpy as np
import torch
import torchaudio
import typer
from torchaudio import functional as FA

app = typer.Typer()


@app.command()
def main(
root: Path = typer.Option(
...,
help="the root path of a directory which contains datasets.",
),
annotation: Path = typer.Option(
...,
help="the path of a dataset to be encoded.",
),
size_shard: int = typer.Option(
100,
help="the max size of a shard (unit: MB)",
),
output: Path = typer.Option(
...,
help="the output path of a directory which stores shards of WebDataset,",
),
) -> None:
import json

from tqdm import tqdm
from webdataset.writer import ShardWriter

n = 0
with annotation.open() as f:
for line in f:
n += 1

output.mkdir(parents=True, exist_ok=True)
pattern = str(output / "%08d.tar")
max_size = size_shard * 1000**2

with tqdm(total=n) as progress:
with annotation.open() as f, ShardWriter(
pattern,
maxsize=max_size,
verbose=0,
) as w:
for line in f:
example = json.loads(line)

path = str(example["path"])
key = path.rsplit(".", maxsplit=1)[0]
class_id = int(example["class_id"])

path_file = root / path
melspec = transform(path_file, resample_rate=16000, durations=1)
if melspec is not None:
w.write(
{
"__key__": key,
"npy": melspec,
"json": {
"class_id": class_id,
"path": path,
},
}
)

progress.update(1)


def transform(path: Path, resample_rate: int, durations: int) -> Optional[np.ndarray]:
from torch_wae.audio import crop_or_pad_last
from torch_wae.network import Preprocess

waveform, sample_rate = torchaudio.load(path)
waveform = torch.mean(waveform, dim=0).unsqueeze(0)

if resample_rate != sample_rate:
waveform = FA.resample(waveform, sample_rate, resample_rate)

frames = waveform.shape[-1]
if frames > durations * resample_rate:
return None

waveform = crop_or_pad_last(resample_rate, durations, waveform)
melspec = Preprocess()(waveform)[0].detach().numpy()

return melspec


if __name__ == "__main__":
app()

0 comments on commit 096aa27

Please sign in to comment.