Skip to content

Commit

Permalink
Merge branch 'generate-pair-with-limited-classes'
Browse files Browse the repository at this point in the history
  • Loading branch information
mitsuse committed Oct 17, 2024
2 parents 327689a + c574580 commit 1d1900f
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 1 deletion.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ readme = "README.md"
license = {file = "LICENSE"}

[project.scripts]
copy_dataset = "torch_wae.cli.copy_dataset:app"
export_to_onnx = "torch_wae.cli.export_to_onnx:app"
generate_pair = "torch_wae.cli.generate_pair:app"

Expand Down
53 changes: 53 additions & 0 deletions src/torch_wae/cli/copy_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations

from pathlib import Path

import typer

# Typer application object; CLI commands are registered on it with @app.command().
app = typer.Typer()


@app.command()
def main(
    dataset: Path = typer.Option(..., help="the path of a dataset to be copied."),
    output: Path = typer.Option(
        ...,
        help="the path of a directory into which the dataset is copied.",
    ),
) -> None:
    """Copy a JSON Lines pair dataset and the files it references into OUTPUT."""
    # NOTE: Typer shows the docstring above as the command's --help text, so it
    # is kept to a single user-facing sentence; details live in comments.
    #
    # Each non-blank line of `dataset` is a JSON object whose "anchor" and
    # "positive" values are file paths relative to the dataset's directory.
    # Those files are copied under `output` at the same relative paths, and
    # finally the dataset file itself is copied to `output / dataset.name`.
    import json
    import shutil

    from tqdm import tqdm

    root_dataset = dataset.parent

    # First pass: count lines so that the progress bar has a total.
    # JSON text is UTF-8, so do not depend on the locale's default encoding.
    with dataset.open(encoding="utf-8") as f:
        n = sum(1 for _ in f)

    # Relative paths already copied; a file referenced by several pairs only
    # needs to be copied once.
    copied = set()

    with tqdm(total=n) as progress, dataset.open(encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at the end of the
            # file) instead of failing inside json.loads.
            if line.strip():
                json_example = json.loads(line)

                for key in ("anchor", "positive"):
                    path_rel = str(json_example[key])
                    if path_rel in copied:
                        continue

                    path_src = root_dataset / path_rel
                    path_dst = output / path_rel

                    path_dst.parent.mkdir(parents=True, exist_ok=True)
                    shutil.copy(path_src, path_dst)
                    copied.add(path_rel)

            # Advance once per line (blank or not) to match the counted total.
            progress.update(1)

    # The dataset may be empty, in which case no directory was created above.
    output.mkdir(parents=True, exist_ok=True)

    dataset_copied = output / dataset.name
    dataset_copied.write_bytes(dataset.read_bytes())


# Run the Typer application only when this module is executed as a script.
if __name__ == "__main__":
    app()
15 changes: 14 additions & 1 deletion src/torch_wae/cli/generate_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from pathlib import Path
from random import Random
from typing import Optional

import typer

Expand All @@ -20,6 +21,10 @@ def generate_pair(
20240820,
help="the random seed",
),
max_classes: Optional[int] = typer.Option(
...,
help="",
),
annotation: Path = typer.Option(
...,
help="the path of an annotation file for classification",
Expand All @@ -37,6 +42,13 @@ def generate_pair(
with annotation.open() as f:
dataset = ClassificationDatasetJson(**json.load(f))

if max_classes is None:
classes = dataset.classes
else:
classes = tuple(random.sample(dataset.classes, max_classes))

set_classes = frozenset(classes)

n_class = len(dataset.classes)
n_example = len(dataset.examples)

Expand All @@ -46,7 +58,8 @@ def generate_pair(

with tqdm(total=n_example) as progress:
for example in dataset.examples:
group_example[example.class_id].append(example)
if dataset.classes[example.class_id] in set_classes:
group_example[example.class_id].append(example)

progress.update(1)

Expand Down

0 comments on commit 1d1900f

Please sign in to comment.