Add ExtractedMask and update importers who can use it to use it (#1480)
### Summary

- Ticket no. 138711
- Add `ExtractedMask`, a dedicated mask annotation class that uses
a single index mask as its source. The index mask is an integer 2D array
whose pixels can indicate a label id (class) or an instance id.
- The advantage of introducing this class is that we can reduce the
memory burden by explicitly sharing a single source. In addition, we can
reduce the computational burden in the following scenario:

| | Create a semantic segmentation map by combining Datumaro Mask annotations |
| :-: | :-: |
| Setup | Assume there is a Datumaro dataset item that has multiple Datumaro Mask annotations. We can merge them into one index mask to construct a semantic segmentation map. |
| Before | 1) For each Datumaro Mask annotation, create a binary mask, 2) Merge the binary masks into one index mask |
| After | 1) Check that all Mask annotations are `ExtractedMask` and that they share the same index mask, 2) Just extract the source index mask |

This saves the substantial computation otherwise required for 2D array
manipulation: creating each binary mask and merging the binary masks.
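
For illustration only, the before/after scenario in the table can be sketched with plain NumPy. The array contents and variable names below are made up for this sketch and are not taken from the Datumaro code base; `sources` is a stand-in for the `index_mask` attribute of each annotation.

```python
import numpy as np

# Hypothetical 4x4 index mask: each pixel stores a label id (0, 1, or 2).
index_mask = np.array(
    [
        [0, 0, 1, 1],
        [0, 2, 2, 1],
        [0, 2, 2, 1],
        [0, 0, 1, 1],
    ],
    dtype=np.uint8,
)
labels = [0, 1, 2]

# "Before": build one binary mask per label, then merge them back together.
binary_masks = {label: index_mask == label for label in labels}
merged = np.zeros_like(index_mask)
for label, binary in binary_masks.items():
    merged[binary] = label

# "After": every annotation points at the same source array, so once that
# is confirmed the source can be reused without any per-pixel work.
sources = [index_mask for _ in labels]  # stand-in for each annotation's index_mask
if all(src is sources[0] for src in sources):
    extracted = sources[0]

# Both routes yield the same segmentation map.
assert np.array_equal(merged, extracted)
```

The "before" route touches every pixel once per label, while the "after" route is an identity check per annotation.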

### How to test
<!-- Describe the testing procedure for reviewers, if changes are
not fully covered by unit tests or manual testing can be complicated.
-->

### Checklist
<!-- Put an 'x' in all the boxes that apply -->
- [x] I have added unit tests to cover my changes.
- [ ] I have added integration tests to cover my changes.
- [x] I have added the description of my changes into
[CHANGELOG](https://github.com/openvinotoolkit/datumaro/blob/develop/CHANGELOG.md).
- [ ] I have updated the
[documentation](https://github.com/openvinotoolkit/datumaro/tree/develop/docs)
accordingly

### License

- [x] I submit _my code changes_ under the same [MIT
License](https://github.com/openvinotoolkit/datumaro/blob/develop/LICENSE)
that covers the project.
  Feel free to contact the maintainers if that's a concern.
- [ ] I have updated the license header for each file (see an example
below).

```python
# Copyright (C) 2024 Intel Corporation
#
# SPDX-License-Identifier: MIT
```

---------

Signed-off-by: Kim, Vinnam <[email protected]>
vinnamkim authored Apr 29, 2024
1 parent f57231d commit d196fc8
Showing 19 changed files with 193 additions and 106 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/1442>)
- Get target information for tabular dataset
(<https://github.com/openvinotoolkit/datumaro/pull/1471>)
- Add ExtractedMask and update importers who can use it to use it
(<https://github.com/openvinotoolkit/datumaro/pull/1480>)

## May 2024 Release 1.6.1
### Enhancements
43 changes: 42 additions & 1 deletion src/datumaro/components/annotation.py
@@ -345,7 +345,9 @@ def __eq__(self, other):


BinaryMaskImage = np.ndarray # 2d array of type bool
BinaryMaskImageCallable = Callable[[], BinaryMaskImage]
IndexMaskImage = np.ndarray # 2d array of type int
IndexMaskImageCallable = Callable[[], IndexMaskImage]


@attrs(slots=True, eq=False, order=False)
@@ -355,7 +357,7 @@ class Mask(Annotation):
"""

_type = AnnotationType.mask
_image = field()
_image: Union[BinaryMaskImage, BinaryMaskImageCallable] = field()
label: Optional[int] = field(
converter=attr.converters.optional(int), default=None, kw_only=True
)
@@ -501,6 +503,45 @@ def __eq__(self, other):
return self.rle == other.rle


@attrs(slots=True, eq=False, order=False)
class ExtractedMask(Mask):
"""Mask annotation (binary mask) extracted from an index mask (integer 2D Numpy array).
This class can extract a binary mask with given index mask and index value.
The advantage of this class is that we can create multiple binary masks that share a single index mask source.
Attributes:
index_mask: Integer 2D Numpy array. Its pixel can indicate a label id (class) or an instance id.
index: Integer value to extract a binary mask from the given index mask.
Examples:
This example demonstrates how to create an `ExtractedMask` from a synthetic index mask,
which denotes a semantic segmentation mask with binary values such as 0 for background
and 1 for foreground.
>>> import numpy as np
>>> from datumaro.components.annotation import ExtractedMask
>>>
>>> index_mask = np.random.randint(low=0, high=2, size=(10, 10), dtype=np.uint8)
>>> mask1 = ExtractedMask(index_mask=index_mask, index=0, label=0) # 0 for background
>>> mask2 = ExtractedMask(index_mask=index_mask, index=1, label=1) # 1 for foreground
>>> np.unique(mask1.image).tolist() # `image` property creates a binary mask
[False, True]
>>> mask1.index_mask is mask2.index_mask # They share the same source
True
"""

index_mask: Union[IndexMaskImage, IndexMaskImageCallable] = field()
index: int = field()

_image: None = field(init=False, default=None)

@property
def image(self) -> BinaryMaskImage:
index_mask = self.index_mask() if callable(self.index_mask) else self.index_mask
return index_mask == self.index


CompiledMaskImage = np.ndarray # 2d of integers (of different precision)
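
A minimal sketch of how the `image` property in this hunk behaves when `index_mask` is a callable rather than an array. `ExtractedMaskSketch` and `load_index_mask` are hypothetical stand-ins written for this note; they mimic the property shown above but do not import Datumaro, and the 2x2 array is made up.

```python
from dataclasses import dataclass
from typing import Callable, List, Union

import numpy as np


@dataclass(eq=False)
class ExtractedMaskSketch:
    """Hypothetical stand-in mirroring the `image` property in the diff."""

    index_mask: Union[np.ndarray, Callable[[], np.ndarray]]
    index: int

    @property
    def image(self) -> np.ndarray:
        # Resolve the source lazily, then compare against the index value.
        source = self.index_mask() if callable(self.index_mask) else self.index_mask
        return source == self.index


calls: List[int] = []  # records each time the loader runs


def load_index_mask() -> np.ndarray:
    """Hypothetical loader; a real importer would read a file here."""
    calls.append(1)
    return np.array([[0, 1], [1, 2]], dtype=np.int32)


lazy = ExtractedMaskSketch(index_mask=load_index_mask, index=1)
calls_before_access = len(calls)  # nothing has been loaded yet

binary = lazy.image  # the loader runs only now
calls_after_access = len(calls)

eager = ExtractedMaskSketch(index_mask=np.array([[0, 1], [1, 2]]), index=2)
eager_binary = eager.image
```

In this sketch the source is resolved on every `image` access, so any caching has to come from the callable itself.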


8 changes: 4 additions & 4 deletions src/datumaro/plugins/data_formats/ade20k2017.py
@@ -12,7 +12,7 @@

import numpy as np

from datumaro.components.annotation import AnnotationType, CompiledMask, LabelCategories, Mask
from datumaro.components.annotation import AnnotationType, ExtractedMask, LabelCategories
from datumaro.components.dataset_base import DatasetBase, DatasetItem
from datumaro.components.errors import InvalidAnnotationError
from datumaro.components.format_detection import FormatDetectionContext
@@ -97,7 +97,6 @@ def _load_items(self, subset):
continue

mask = lazy_image(mask_path, loader=self._load_instance_mask)
mask = CompiledMask(instance_mask=mask)

for v in item_info:
if v["part_level"] != part_level:
@@ -108,9 +107,10 @@ def _load_items(self, subset):
attributes = {k: True for k in v["attributes"]}

item_annotations.append(
Mask(
ExtractedMask(
index_mask=mask,
index=instance_id,
label=label_id,
image=mask.lazy_extract(instance_id),
id=instance_id,
attributes=attributes,
z_order=part_level,
16 changes: 10 additions & 6 deletions src/datumaro/plugins/data_formats/ade20k2020.py
@@ -8,13 +8,14 @@
import os
import os.path as osp
import re
from functools import partial
from typing import List, Optional

import numpy as np

from datumaro.components.annotation import (
AnnotationType,
CompiledMask,
ExtractedMask,
LabelCategories,
Mask,
Polygon,
@@ -100,7 +101,6 @@ def _load_items(self, subset):
continue

mask = lazy_image(mask_path, loader=self._load_class_mask)
mask = CompiledMask(instance_mask=mask)

classes = {
(v["class_idx"], v["label_name"])
@@ -111,10 +111,11 @@ def _load_items(self, subset):
for class_idx, label_name in classes:
label_id = labels.find(label_name)[0]
item_annotations.append(
Mask(
ExtractedMask(
index_mask=mask,
index=class_idx,
label=label_id,
id=class_idx,
image=mask.lazy_extract(class_idx),
group=class_idx,
z_order=part_level,
)
@@ -129,7 +130,6 @@ def _load_items(self, subset):
continue

mask = lazy_image(instance_path, loader=self._load_instance_mask)
mask = CompiledMask(instance_mask=mask)

label_id = labels.find(item["label_name"])[0]
instance_id = item["id"]
@@ -139,7 +139,7 @@ def _load_items(self, subset):
item_annotations.append(
Mask(
label=label_id,
image=mask.lazy_extract(1),
image=partial(self._get_instance_mask, mask),
id=instance_id,
attributes=attributes,
z_order=item["part_level"],
@@ -219,6 +219,10 @@ def _load_class_mask(path):
mask = ((mask[:, :, 2] / 10).astype(np.int32) << 8) + mask[:, :, 1].astype(np.int32)
return mask

@staticmethod
def _get_instance_mask(mask: lazy_image) -> np.ndarray:
return mask() == 1


class Ade20k2020Importer(Importer):
_ANNO_EXT = ".json"
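
The `partial(self._get_instance_mask, mask)` call in this file defers the `== 1` comparison until the mask image is actually requested. A minimal sketch of the same pattern follows; `fake_loader` is a hypothetical stand-in for the `lazy_image` object, and the array values are made up.

```python
from functools import partial

import numpy as np


def get_instance_mask(mask_loader) -> np.ndarray:
    """Mirror of the static helper in the diff: load, then keep pixels equal to 1."""
    return mask_loader() == 1


def fake_loader() -> np.ndarray:
    """Hypothetical loader standing in for a lazily loaded mask image."""
    return np.array([[0, 1], [1, 1]], dtype=np.uint8)


# Binding the loader yields a zero-argument callable, which is the shape
# that `Mask(image=...)` accepts for lazily computed binary masks.
lazy_binary = partial(get_instance_mask, fake_loader)

result = lazy_binary()  # the load and the comparison both happen here
```

The commit does not state why `partial` was chosen over a lambda. One plausible reason, offered here only as an assumption, is that a `partial` over a named function is easier to pickle and inspect than a closure.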
14 changes: 6 additions & 8 deletions src/datumaro/plugins/data_formats/brats.py
@@ -10,7 +10,7 @@
import nibabel as nib
import numpy as np

from datumaro.components.annotation import AnnotationType, LabelCategories, Mask
from datumaro.components.annotation import AnnotationType, ExtractedMask, LabelCategories
from datumaro.components.dataset_base import DatasetItem, SubsetBase
from datumaro.components.format_detection import FormatDetectionContext
from datumaro.components.importer import ImportContext, Importer
@@ -80,11 +80,13 @@ def _load_items(self, path):

anno = []
for i in range(data.shape[2]):
classes = np.unique(data[:, :, i])
np_mask = data[:, :, i]
classes = np.unique(np_mask)
for class_id in classes:
anno.append(
Mask(
image=self._lazy_extract_mask(data[:, :, i], class_id),
ExtractedMask(
index_mask=np_mask,
index=class_id,
label=class_id,
attributes={"image_id": i},
)
@@ -95,10 +97,6 @@ def _load_items(self, path):

return items

@staticmethod
def _lazy_extract_mask(mask, c):
return lambda: mask == c


class BratsImporter(Importer):
@classmethod
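
In the BraTS hunk above, each `np_mask` slice is handed to several `ExtractedMask` annotations, one per class id. The sketch below uses a made-up volume to suggest why this avoids copies: basic NumPy slicing returns a view, so all annotations for a slice reference the same memory as the loaded volume. The tuple list is a hypothetical stand-in for the annotation objects.

```python
import numpy as np

# Hypothetical 4x4 volume with 2 slices; the values are class ids.
data = np.zeros((4, 4, 2), dtype=np.uint8)
data[1:3, 1:3, 0] = 1
data[0, 0, 1] = 2

shared_slices = []  # (index_mask, index) pairs standing in for annotations
for i in range(data.shape[2]):
    np_mask = data[:, :, i]  # basic slicing: a view, not a copy
    for class_id in np.unique(np_mask):
        shared_slices.append((np_mask, int(class_id)))

# Every stand-in annotation shares memory with the original volume.
all_shared = all(np.shares_memory(mask, data) for mask, _ in shared_slices)

# A binary mask is only materialized when it is actually needed.
first_foreground = shared_slices[1][0] == shared_slices[1][1]
```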
15 changes: 7 additions & 8 deletions src/datumaro/plugins/data_formats/brats_numpy.py
@@ -8,7 +8,7 @@

import numpy as np

from datumaro.components.annotation import AnnotationType, Cuboid3d, LabelCategories, Mask
from datumaro.components.annotation import AnnotationType, Cuboid3d, ExtractedMask, LabelCategories
from datumaro.components.dataset_base import DatasetItem, SubsetBase
from datumaro.components.format_detection import FormatDetectionContext
from datumaro.components.importer import ImportContext, Importer
@@ -66,6 +66,7 @@ def _load_items(self, path):
with open(boxes_file, "rb") as f:
boxes = PickleLoader.restricted_load(f)

# TODO(vinnamki): Apply lazy loading for images and masks
for i, item_id in enumerate(ids):
image_path = osp.join(self._root_dir, item_id + BratsNumpyPath.DATA_SUFFIX + ".npy")
media = None
@@ -82,11 +83,13 @@ def _load_items(self, path):
if osp.isfile(mask_path):
mask = np.load(mask_path)[0].transpose()
for j in range(mask.shape[2]):
classes = np.unique(mask[:, :, j])
np_mask = mask[:, :, j]
classes = np.unique(np_mask)
for class_id in classes:
anno.append(
Mask(
image=self._lazy_extract_mask(mask[:, :, j], class_id),
ExtractedMask(
index_mask=np_mask,
index=class_id,
label=class_id,
attributes={"image_id": j},
)
@@ -102,10 +105,6 @@ def _load_items(self, path):

return items

@staticmethod
def _lazy_extract_mask(mask, c):
return lambda: mask == c


class BratsNumpyImporter(Importer):
@classmethod
19 changes: 10 additions & 9 deletions src/datumaro/plugins/data_formats/camvid.py
@@ -15,8 +15,8 @@
from datumaro.components.annotation import (
AnnotationType,
CompiledMask,
ExtractedMask,
LabelCategories,
Mask,
MaskCategories,
)
from datumaro.components.dataset_base import DatasetItem, SubsetBase
@@ -225,13 +225,18 @@ def _load_items(self, path):
mask = lazy_mask(
gt_path, self._categories[AnnotationType.mask].inverse_colormap
)
mask = mask() # loading mask through cache
np_mask = mask() # loading mask through cache

classes = np.unique(mask)
classes = np.unique(np_mask)
for label_id in classes:
if labels[label_id] in self._labels:
image = self._lazy_extract_mask(mask, label_id)
item_annotations.append(Mask(image=image, label=label_id))
item_annotations.append(
ExtractedMask(
index_mask=mask,
index=label_id,
label=label_id,
)
)

self._ann_types.add(AnnotationType.mask)

@@ -244,10 +249,6 @@ def _load_items(self, path):

return items

@staticmethod
def _lazy_extract_mask(mask, c):
return lambda: mask == c


class CamvidImporter(Importer):
_ANNO_EXT = ".txt"
16 changes: 9 additions & 7 deletions src/datumaro/plugins/data_formats/cityscapes.py
@@ -16,8 +16,8 @@
from datumaro.components.annotation import (
AnnotationType,
CompiledMask,
ExtractedMask,
LabelCategories,
Mask,
MaskCategories,
)
from datumaro.components.dataset_base import DatasetItem, SubsetBase
@@ -30,7 +30,7 @@
from datumaro.components.task import TaskAnnotationMapping
from datumaro.util import find
from datumaro.util.annotation_util import make_label_id_mapping
from datumaro.util.image import find_images, load_image, save_image
from datumaro.util.image import find_images, lazy_image, save_image
from datumaro.util.mask_tools import generate_colormap, paint_mask
from datumaro.util.meta_file_util import has_meta_file, is_meta_file, parse_meta_file

@@ -282,16 +282,18 @@ def _load_items(self):
item_id = self._get_id_from_mask_path(mask_path, mask_suffix)

anns = []
instances_mask = load_image(mask_path, dtype=np.int32)
index_mask = lazy_image(mask_path, dtype=np.int32)
np_index_mask = index_mask()

mask_id = 1
for label_id in label_ids:
if label_id not in instances_mask:
if label_id not in np_index_mask:
continue
binary_mask = self._lazy_extract_mask(instances_mask, label_id)
anns.append(
Mask(
ExtractedMask(
index_mask=index_mask,
index=label_id,
id=mask_id,
image=binary_mask,
label=label_id,
)
)
21 changes: 16 additions & 5 deletions src/datumaro/plugins/data_formats/common_semantic_segmentation.py
@@ -9,7 +9,12 @@

import numpy as np

from datumaro.components.annotation import AnnotationType, LabelCategories, Mask, MaskCategories
from datumaro.components.annotation import (
AnnotationType,
ExtractedMask,
LabelCategories,
MaskCategories,
)
from datumaro.components.dataset_base import DatasetItem, SubsetBase
from datumaro.components.format_detection import FormatDetectionConfidence, FormatDetectionContext
from datumaro.components.importer import ImportContext, Importer, with_subset_dirs
@@ -108,13 +113,19 @@ def _load_items(self):
image = Image.from_file(path=image)

annotations = []
mask = lazy_mask(mask_path, self._categories[AnnotationType.mask].inverse_colormap)
mask = mask() # loading mask through cache
index_mask = lazy_mask(
mask_path, self._categories[AnnotationType.mask].inverse_colormap
)
np_mask = index_mask() # loading mask through cache

classes = np.unique(mask)
classes = np.unique(np_mask)
for label_id in classes:
annotations.append(
Mask(image=self._lazy_extract_mask(mask, label_id), label=label_id)
ExtractedMask(
index_mask=index_mask,
index=label_id,
label=label_id,
)
)
self._ann_types.add(AnnotationType.mask)
