-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
194 additions
and
1 deletion.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from __future__ import annotations | ||
|
||
from pathlib import Path | ||
|
||
import typer | ||
|
||
app = typer.Typer() | ||
|
||
|
||
@app.command()
def main(
    annotation: Path = typer.Option(
        ...,
        help="the path of a dataset to be encoded.",
    ),
    output: Path = typer.Option(
        ...,
        # Fixed: stray trailing comma inside the help text ("file.,").
        help="the output path of a converted JSONL file.",
    ),
) -> None:
    """Convert a ClassificationDataset annotation into a JSONL file.

    Each output line is one JSON object with ``path`` and ``class_id`` keys,
    one per dataset example. The output's parent directory is created if
    missing.
    """
    # Heavy imports are kept local so the CLI module loads fast.
    import json

    from tqdm import tqdm

    from torch_wae.dataset import ClassificationDataset

    dataset = ClassificationDataset(
        annotation=annotation,
        root=annotation.parent,
    )

    n = len(dataset)

    output.parent.mkdir(parents=True, exist_ok=True)
    with tqdm(total=n) as progress, output.open(mode="w") as f:
        for i in range(n):
            example = dataset[i]
            # NOTE(review): assumes example.path is JSON-serializable
            # (a str, not a pathlib.Path) — confirm in ClassificationDataset.
            line = json.dumps(
                {
                    "path": example.path,
                    "class_id": example.class_id,
                }
            )
            f.write(f"{line}\n")
            progress.update(1)
|
||
|
||
# Script entry point: run the Typer CLI when executed directly.
if __name__ == "__main__":
    app()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from __future__ import annotations | ||
|
||
from pathlib import Path | ||
from typing import Optional | ||
|
||
import numpy as np | ||
import torch | ||
import torchaudio | ||
import typer | ||
from torchaudio import functional as FA | ||
|
||
app = typer.Typer() | ||
|
||
|
||
@app.command()
def main(
    root: Path = typer.Option(
        ...,
        help="the root path of a directory which contains datasets.",
    ),
    annotation: Path = typer.Option(
        ...,
        help="the path of a dataset to be encoded.",
    ),
    size_shard: int = typer.Option(
        100,
        help="the max size of a shard (unit: MB)",
    ),
    output: Path = typer.Option(
        ...,
        # Fixed: stray trailing comma inside the help text ("WebDataset,").
        help="the output path of a directory which stores shards of WebDataset.",
    ),
) -> None:
    """Encode a JSONL annotation file into WebDataset tar shards.

    Each input line holds a JSON object with ``path`` and ``class_id``.
    Every example is converted to a mel-spectrogram via :func:`transform`;
    examples rejected by ``transform`` (returns None) are skipped.
    """
    # Heavy imports are kept local so the CLI module loads fast.
    import json

    from tqdm import tqdm
    from webdataset.writer import ShardWriter

    # Pre-count lines so tqdm can show a total; idiomatic one-pass count.
    with annotation.open() as f:
        n = sum(1 for _ in f)

    output.mkdir(parents=True, exist_ok=True)
    pattern = str(output / "%08d.tar")
    max_size = size_shard * 1000**2  # MB -> bytes (decimal megabytes)

    with tqdm(total=n) as progress:
        with annotation.open() as f, ShardWriter(
            pattern,
            maxsize=max_size,
            verbose=0,
        ) as w:
            for line in f:
                example = json.loads(line)

                path = str(example["path"])
                key = path.rsplit(".", maxsplit=1)[0]  # drop file extension
                class_id = int(example["class_id"])

                path_file = root / path
                melspec = transform(path_file, resample_rate=16000, durations=1)
                if melspec is not None:
                    w.write(
                        {
                            "__key__": key,
                            "npy": melspec,
                            "json": {
                                "class_id": class_id,
                                "path": path,
                            },
                        }
                    )

                # Progress advances even for skipped examples.
                progress.update(1)
|
||
|
||
def transform(path: Path, resample_rate: int, durations: int) -> Optional[np.ndarray]:
    """Load an audio file and convert it to a mel-spectrogram array.

    Returns None when the (resampled) clip is longer than ``durations``
    seconds; shorter clips are padded to exactly that length before being
    fed through ``Preprocess``.
    """
    from torch_wae.audio import crop_or_pad_last
    from torch_wae.network import Preprocess

    signal, source_rate = torchaudio.load(path)
    # Mix down to mono while keeping a leading channel dimension of 1.
    signal = signal.mean(dim=0, keepdim=True)

    if source_rate != resample_rate:
        signal = FA.resample(signal, source_rate, resample_rate)

    limit = durations * resample_rate
    if signal.shape[-1] > limit:
        # Clip exceeds the target duration — reject it.
        return None

    padded = crop_or_pad_last(resample_rate, durations, signal)
    return Preprocess()(padded)[0].detach().numpy()
|
||
|
||
# Script entry point: run the Typer CLI when executed directly.
if __name__ == "__main__":
    app()