From d35ec9edb66d1fd21202eef43d9115317c98478f Mon Sep 17 00:00:00 2001 From: "Yi, Jihyeon" Date: Tue, 8 Oct 2024 13:22:49 +0900 Subject: [PATCH] ensure unique and handle VideoFrame too --- src/datumaro/plugins/transforms.py | 85 ++++++++++++++++++++++++++++- tests/unit/test_transforms.py | 86 ++++++++++++++++++++++-------- 2 files changed, 147 insertions(+), 24 deletions(-) diff --git a/src/datumaro/plugins/transforms.py b/src/datumaro/plugins/transforms.py index f1515940c2..2d8707fd32 100644 --- a/src/datumaro/plugins/transforms.py +++ b/src/datumaro/plugins/transforms.py @@ -9,6 +9,7 @@ import os.path as osp import random import re +import string from collections import Counter, defaultdict from copy import deepcopy from enum import Enum, auto @@ -63,7 +64,7 @@ UndefinedAttribute, UndefinedLabel, ) -from datumaro.components.media import Image, TableRow +from datumaro.components.media import Image, TableRow, VideoFrame from datumaro.components.transformer import ItemTransform, Transform from datumaro.util import NOTSET, filter_dict, parse_json_file, parse_str_enum_value, take_by from datumaro.util.annotation_util import find_group_leader, find_instances @@ -595,12 +596,92 @@ def __iter__(self): class IdFromImageName(ItemTransform, CliPlugin): """ - Renames items in the dataset using image file name (without extension). + Renames items in the dataset based on the image file name, excluding the extension.|n + When 'ensure_unique' is enabled, a random suffix is appened to ensure each identifier is unique + in cases where the image name is not distinct. By default, the random suffix is three characters long, + but this can be adjusted with the 'suffix_length' parameter.|n + |n + Examples:|n + |n + |s|s- Renames items without duplication check:|n + + .. code-block:: + + |s|s|s|s%(prog)s|n + |n + |s|s- Renames items with duplication check:|n + + .. code-block:: + + |s|s|s|s%(prog)s --ensure_unique|n + |n + |s|s- Renames items with duplication check and alters the suffix length(default: 3):|n + + .. code-block:: + + |s|s|s|s%(prog)s --ensure_unique --suffix_length 2 """ + DEFAULT_RETRY = 1000 + SUFFIX_LETTERS = string.ascii_lowercase + string.digits + + @classmethod + def build_cmdline_parser(cls, **kwargs): + parser = super().build_cmdline_parser(**kwargs) + parser.add_argument( + "-u", + "--ensure_unique", + action="store_true", + help="Appends a random suffix to ensure each identifier is unique if the image name is duplicated.", + ) + parser.add_argument( + "-l", + "--suffix_length", + type=int, + default=3, + help="Alters the length of the random suffix if the 'ensure_unique' is enabled.", + ) + + return parser + + def __init__(self, extractor, ensure_unique: bool = False, suffix_length: int = 3): + super().__init__(extractor) + self._length = "parent" + self._ensure_unique = ensure_unique + self._names: set[str] = set() + self._suffix_length = suffix_length + if suffix_length <= 0: + raise ValueError( + f"The 'suffix_length' must be greater than 0. Received: {suffix_length}." + ) + self._max_retry = min( + self.DEFAULT_RETRY, pow(len(self.SUFFIX_LETTERS), self._suffix_length) + ) + + def _add_unique_suffix(self, name): + count = 0 + while name in self._names: + suffix = "".join(random.choices(self.SUFFIX_LETTERS, k=self._suffix_length)) + new_name = f"{name}__{suffix}" + if new_name not in self._names: + name = new_name + break + count += 1 + if count == self._max_retry: + raise Exception( + f"Too many duplicate names. Failed to generate a unique suffix after {self._max_retry} attempts." + ) + + self._names.add(name) + return name + def transform_item(self, item): if isinstance(item.media, Image) and hasattr(item.media, "path"): name = osp.splitext(osp.basename(item.media.path))[0] + if isinstance(item.media, VideoFrame): + name += f"_frame-{item.media.index}" + if self._ensure_unique: + name = self._add_unique_suffix(name) return self.wrap_item(item, id=name) else: log.debug("Can't change item id for item '%s': " "item has no path info" % item.id) diff --git a/tests/unit/test_transforms.py b/tests/unit/test_transforms.py index 25f01caff8..79b9247daf 100644 --- a/tests/unit/test_transforms.py +++ b/tests/unit/test_transforms.py @@ -5,7 +5,9 @@ import os import os.path as osp import random +import re from unittest import TestCase +from unittest.mock import patch import numpy as np import pandas as pd @@ -33,10 +35,10 @@ Tabular, TabularCategories, ) -from datumaro.components.dataset import Dataset +from datumaro.components.dataset import Dataset, eager_mode from datumaro.components.dataset_base import DatasetItem from datumaro.components.errors import AnnotationTypeError -from datumaro.components.media import Image, Table, TableRow +from datumaro.components.media import Image, Table, TableRow, Video, VideoFrame from ..requirements import Requirements, mark_bug, mark_requirement @@ -420,26 +422,6 @@ def test_shapes_to_boxes(self): actual = transforms.ShapesToBoxes(source_dataset) compare_datasets(self, target_dataset, actual) - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_id_from_image(self): - source_dataset = Dataset.from_iterable( - [ - DatasetItem(id=1, media=Image.from_file(path="path.jpg")), - DatasetItem(id=2), - DatasetItem(id=3, media=Image.from_numpy(data=np.ones([5, 5, 3]))), - ] - ) - target_dataset = Dataset.from_iterable( - [ - DatasetItem(id="path", media=Image.from_file(path="path.jpg")), - DatasetItem(id=2), - DatasetItem(id=3, media=Image.from_numpy(data=np.ones([5, 5, 3]))), - ] - ) - - actual = transforms.IdFromImageName(source_dataset) - compare_datasets(self, target_dataset, actual) - @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_boxes_to_masks(self): source_dataset = Dataset.from_iterable( @@ -1227,6 +1209,66 @@ def test_annotation_reindex(self, fxt_dataset: Dataset, reindex_each_item: bool) ) +class IdFromImageNameTest: + @pytest.fixture + def fxt_dataset(self, n_labels=3, n_anns=5, n_items=7) -> Dataset: + video = Video("video.mp4") + return Dataset.from_iterable( + [ + DatasetItem(id=1, media=Image.from_file(path="path1.jpg")), + DatasetItem(id=2, media=Image.from_file(path="path1.jpg")), + DatasetItem(id=3, media=Image.from_file(path="path1.jpg")), + DatasetItem(id=4, media=VideoFrame(video, index=30)), + DatasetItem(id=5, media=VideoFrame(video, index=30)), + DatasetItem(id=6, media=VideoFrame(video, index=60)), + DatasetItem(id=7), + DatasetItem(id=8, media=Image.from_numpy(data=np.ones([5, 5, 3]))), + ] + ) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @pytest.mark.parametrize("ensure_unique", [True, False]) + def test_id_from_image(self, fxt_dataset, ensure_unique): + source_dataset = fxt_dataset + actual_dataset = transforms.IdFromImageName(source_dataset, ensure_unique=ensure_unique) + + unique_names: set[str] = set() + for src, actual in zip(source_dataset, actual_dataset): + if not isinstance(src.media, Image) or not hasattr(src.media, "path"): + src == actual + else: + if isinstance(src.media, VideoFrame): + expected_id = f"video_frame-{src.media.index}" + else: + expected_id = os.path.splitext(src.media.path)[0] + if ensure_unique: + assert actual.id.startswith(expected_id) + assert actual.wrap(id=src.id) == src + assert actual.id not in unique_names + unique_names.add(actual.id) + else: + assert actual == src.wrap(id=expected_id) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_id_from_image_wrong_suffix_length(self, fxt_dataset): + with pytest.raises(ValueError) as e: + transforms.IdFromImageName(fxt_dataset, ensure_unique=True, suffix_length=0) + assert str(e.value).startswith("The 'suffix_length' must be greater than 0.") + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_id_from_image_too_many_duplication(self, fxt_dataset): + with patch("datumaro.plugins.transforms.IdFromImageName.DEFAULT_RETRY", 1), patch( + "datumaro.plugins.transforms.IdFromImageName.SUFFIX_LETTERS", "a" + ), pytest.raises(Exception) as e: + with eager_mode(): + fxt_dataset.transform( + "id_from_image_name", + ensure_unique=True, + suffix_length=1, + ) + assert str(e.value).startswith("Too many duplicate names.") + + class AstypeAnnotationsTest(TestCase): def setUp(self): self.table = Table.from_list(