Skip to content

Commit

Permalink
Add SAMAutomaticMaskGeneration transform (openvinotoolkit#1168)
Browse files Browse the repository at this point in the history
- Ticket no. 120781.

Signed-off-by: Kim, Vinnam <[email protected]>
Co-authored-by: Wonju Lee <[email protected]>
  • Loading branch information
vinnamkim and wonjuleee authored Oct 16, 2023
1 parent 243d721 commit ba6f0ed
Show file tree
Hide file tree
Showing 8 changed files with 907 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## \[Unreleased\]
### New features
- Add SAMAutomaticMaskGeneration transform
(<https://github.com/openvinotoolkit/datumaro/pull/1168>)

### Bug fixes
- Coco exporter can export annotations even if there is no media, except for mask annotations which require media info.
(<https://github.com/openvinotoolkit/datumaro/issues/1147>)(<https://github.com/openvinotoolkit/datumaro/pull/1158>)
Expand Down
10 changes: 10 additions & 0 deletions docs/source/docs/jupyter_notebook_examples/transform.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Please refer `here <https://github.com/openvinotoolkit/datumaro/blob/develop/src
notebooks/05_transform
notebooks/06_tiling
notebooks/18_bbox_to_instance_mask_using_sam
notebooks/19_automatic_instance_mask_gen_using_sam

.. grid:: 1 2 2 2
:gutter: 2
Expand Down Expand Up @@ -44,6 +45,15 @@ Please refer `here <https://github.com/openvinotoolkit/datumaro/blob/develop/src

This transform uses Segment Anything Model [2]_ to transform bounding box annotations to instance mask annotations.

.. grid-item-card::

.. button-ref:: notebooks/19_automatic_instance_mask_gen_using_sam
:color: primary
:outline:
:expand:

This transform uses Segment Anything Model [2]_ to generate instance maks annotations automatically.

References
^^^^^^^^^^

Expand Down
349 changes: 349 additions & 0 deletions notebooks/19_automatic_instance_mask_gen_using_sam.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion src/datumaro/plugins/sam_transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: MIT
"""Transforms using Segment Anything Model"""
from .automatic_mask_gen import SAMAutomaticMaskGeneration
from .bbox_to_inst_mask import SAMBboxToInstanceMask

__all__ = ["SAMBboxToInstanceMask"]
__all__ = ["SAMBboxToInstanceMask", "SAMAutomaticMaskGeneration"]
181 changes: 181 additions & 0 deletions src/datumaro/plugins/sam_transforms/automatic_mask_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT
"""Automatic mask generation using Segment Anything Model"""

import os.path as osp
from typing import List, Optional

import numpy as np

import datumaro.plugins.sam_transforms.interpreters.sam_decoder_for_amg as sam_decoder_for_amg
import datumaro.plugins.sam_transforms.interpreters.sam_encoder as sam_encoder_interp
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.dataset_base import DatasetItem, IDataset
from datumaro.components.transformer import ModelTransform
from datumaro.plugins.inference_server_plugin import OVMSLauncher, TritonLauncher
from datumaro.plugins.inference_server_plugin.base import (
InferenceServerType,
ProtocolType,
TLSConfig,
)
from datumaro.plugins.sam_transforms.interpreters.sam_decoder_for_amg import AMGMasks, AMGPoints

__all__ = ["SAMAutomaticMaskGeneration"]


class SAMAutomaticMaskGeneration(ModelTransform, CliPlugin):
"""Produce instance segmentation masks automatically using Segment Anything Model (SAM).
This transform can produce instance segmentation mask annotations for each given image.
It samples single-point input prompts on a uniform 2D grid over the image.
For each prompt, SAM can predict multiple masks. After obtaining the mask candidates,
it post-processes them using the given parameters to improve quality and remove duplicates.
It uses the Segment Anything Model deployed in the OpenVINO™ Model Server
or NVIDIA Triton™ Inference Server instance. To launch the server instance,
please see the guide in this link:
https://github.com/openvinotoolkit/datumaro/tree/develop/docker/segment-anything/README.md
Parameters:
extractor: Dataset to transform
inference_server_type: Inference server type:
`InferenceServerType.ovms` or `InferenceServerType.triton`
host: Host address of the server instance
port: Port number of the server instance
timeout: Timeout limit during communication between the client and the server instance
tls_config: Configuration required if the server instance is in the secure mode
protocol_type: Communication protocol type with the server instance
num_workers: The number of worker threads to use for parallel inference.
Set to 0 for single-process mode. Default is 0.
points_per_side (int): The number of points to be sampled
along one side of the image. The total number of points is
points_per_side**2 on a uniform 2d grid.
points_per_batch (int): Sets the number of points run simultaneously
by the model. Higher numbers may be faster but use more GPU memory.
pred_iou_thresh (float): A filtering threshold in [0,1], using the
model's predicted mask quality.
stability_score_thresh (float): A filtering threshold in [0,1], using
the stability of the mask under changes to the cutoff used to binarize
the model's mask predictions.
stability_score_offset (float): The amount to shift the cutoff when
calculated the stability score.
box_nms_thresh (float): The box IoU cutoff used by non-maximal
suppression to filter duplicate masks.
min_mask_region_area (int): If >0, postprocessing will be applied
to remove the binary mask which has the number of 1s less than min_mask_region_area.
"""

def __init__(
self,
extractor: IDataset,
inference_server_type: InferenceServerType = InferenceServerType.ovms,
host: str = "localhost",
port: int = 9000,
timeout: float = 10.0,
tls_config: Optional[TLSConfig] = None,
protocol_type: ProtocolType = ProtocolType.grpc,
num_workers: int = 0,
points_per_side: int = 32,
points_per_batch: int = 128,
mask_threshold: float = 0.0,
pred_iou_thresh: float = 0.88,
stability_score_thresh: float = 0.95,
stability_score_offset: float = 1.0,
box_nms_thresh: float = 0.7,
min_mask_region_area: int = 0,
):
if inference_server_type == InferenceServerType.ovms:
launcher_cls = OVMSLauncher
elif inference_server_type == InferenceServerType.triton:
launcher_cls = TritonLauncher
else:
raise ValueError(inference_server_type)

self._sam_encoder_launcher = launcher_cls(
model_name="sam_encoder",
model_interpreter_path=osp.abspath(sam_encoder_interp.__file__),
model_version=1,
host=host,
port=port,
timeout=timeout,
tls_config=tls_config,
protocol_type=protocol_type,
)
self._sam_decoder_launcher = launcher_cls(
model_name="sam_decoder",
model_interpreter_path=osp.abspath(sam_decoder_for_amg.__file__),
model_version=1,
host=host,
port=port,
timeout=timeout,
tls_config=tls_config,
protocol_type=protocol_type,
)

self.points_per_side = points_per_side
self.points_per_batch = points_per_batch
self.mask_threshold = mask_threshold
self.pred_iou_thresh = pred_iou_thresh
self.stability_score_offset = stability_score_offset
self.stability_score_thresh = stability_score_thresh
self.box_nms_thresh = box_nms_thresh
self.min_mask_region_area = min_mask_region_area

super().__init__(
extractor,
launcher=self._sam_encoder_launcher,
batch_size=1,
append_annotation=False,
num_workers=num_workers,
)

@property
def points_per_side(self) -> int:
return self._points_per_side

@points_per_side.setter
def points_per_side(self, points_per_side: int) -> None:
points_y = (np.arange(points_per_side) + 0.5) / points_per_side
points_x = (np.arange(points_per_side) + 0.5) / points_per_side

points_x = np.tile(points_x[None, :], (points_per_side, 1))
points_y = np.tile(points_y[:, None], (1, points_per_side))
self._points_grid = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
self._points_per_side = points_per_side

def _process_batch(
self,
batch: List[DatasetItem],
) -> List[DatasetItem]:
img_embeds = self._sam_encoder_launcher.launch(
batch=[item for item in batch if self._sam_encoder_launcher.type_check(item)]
)

items = []
for item, img_embed in zip(batch, img_embeds):
amg_masks: List[AMGMasks] = []

for i in range(0, len(self._points_grid), self.points_per_batch):
amg_points = [AMGPoints(points=self._points_grid[i : i + self.points_per_batch])]
item_to_decode = item.wrap(annotations=amg_points + img_embed)

# Nested list of mask [[mask_0, ...]]
nested_masks: List[List[AMGMasks]] = self._sam_decoder_launcher.launch(
[item_to_decode],
stack=False,
)
amg_masks += nested_masks[0]

mask_anns = AMGMasks.cat(amg_masks).postprocess(
mask_threshold=self.mask_threshold,
pred_iou_thresh=self.pred_iou_thresh,
stability_score_offset=self.stability_score_offset,
stability_score_thresh=self.stability_score_thresh,
box_nms_thresh=self.box_nms_thresh,
min_mask_region_area=self.min_mask_region_area,
)

items.append(item.wrap(annotations=mask_anns))

return items
Loading

0 comments on commit ba6f0ed

Please sign in to comment.